server.go 7.2 KB


  1. // Snowflake-specific websocket server plugin. It reports the transport name as
  2. // "snowflake".
  3. package main
  4. import (
  5. "flag"
  6. "fmt"
  7. "io"
  8. "io/ioutil"
  9. "log"
  10. "net"
  11. "net/http"
  12. "os"
  13. "os/signal"
  14. "path/filepath"
  15. "strings"
  16. "sync"
  17. "syscall"
  18. "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
  19. "golang.org/x/crypto/acme/autocert"
  20. pt "git.torproject.org/pluggable-transports/goptlib.git"
  21. sf "git.torproject.org/pluggable-transports/snowflake.git/server/lib"
  22. )
  23. const ptMethodName = "snowflake"
  24. var ptInfo pt.ServerInfo
  // usage writes a description of the program (with os.Args[0] substituted
// as the program name) to standard error, followed by the registered flag
// defaults from flag.PrintDefaults. main installs it as flag.Usage.
func usage() {
	const banner = `Usage: %s [OPTIONS]
WebSocket server pluggable transport for Snowflake. Works only as a managed
proxy. Uses TLS with ACME (Let's Encrypt) by default. Set the certificate
hostnames with the --acme-hostnames option. Use ServerTransportListenAddr in
torrc to choose the listening port. When using TLS, this program will open an
additional HTTP listener on port 80 to work with ACME.
`
	program := os.Args[0]
	fmt.Fprintf(os.Stderr, banner, program)
	flag.PrintDefaults()
}
  // proxy relays data in both directions between local (the ORPort
// connection) and conn (the client-facing WebSocket connection), and
// returns only after both directions have finished.
//
// When the ORPort->WebSocket copy ends, the read half of local is shut down
// and conn is closed; when the WebSocket->ORPort copy ends, the write half
// of local is shut down and conn is closed. Copy errors other than
// io.ErrClosedPipe are logged. proxy never calls local.Close, so releasing
// local's descriptor is the caller's responsibility.
func proxy(local *net.TCPConn, conn net.Conn) {
	var pending sync.WaitGroup
	pending.Add(2)

	// ORPort -> WebSocket.
	go func() {
		defer pending.Done()
		_, copyErr := io.Copy(conn, local)
		if copyErr != nil && copyErr != io.ErrClosedPipe {
			log.Printf("error copying ORPort to WebSocket %v", copyErr)
		}
		local.CloseRead()
		conn.Close()
	}()

	// WebSocket -> ORPort.
	go func() {
		defer pending.Done()
		_, copyErr := io.Copy(local, conn)
		if copyErr != nil && copyErr != io.ErrClosedPipe {
			log.Printf("error copying WebSocket to ORPort %v", copyErr)
		}
		local.CloseWrite()
		conn.Close()
	}()

	pending.Wait()
}
  57. func acceptLoop(ln net.Listener) {
  58. for {
  59. conn, err := ln.Accept()
  60. if err != nil {
  61. if err, ok := err.(net.Error); ok && err.Temporary() {
  62. continue
  63. }
  64. log.Printf("Snowflake accept error: %s", err)
  65. break
  66. }
  67. defer conn.Close()
  68. addr := conn.RemoteAddr().String()
  69. statsChannel <- addr != ""
  70. or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
  71. if err != nil {
  72. log.Printf("failed to connect to ORPort: %s", err)
  73. continue
  74. }
  75. defer or.Close()
  76. go proxy(or, conn)
  77. }
  78. }
  79. func getCertificateCacheDir() (string, error) {
  80. stateDir, err := pt.MakeStateDir()
  81. if err != nil {
  82. return "", err
  83. }
  84. return filepath.Join(stateDir, "snowflake-certificate-cache"), nil
  85. }
  86. func main() {
  87. var acmeEmail string
  88. var acmeHostnamesCommas string
  89. var disableTLS bool
  90. var logFilename string
  91. var unsafeLogging bool
  92. flag.Usage = usage
  93. flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications")
  94. flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate")
  95. flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
  96. flag.StringVar(&logFilename, "log", "", "log file to write to")
  97. flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
  98. flag.Parse()
  99. log.SetFlags(log.LstdFlags | log.LUTC)
  100. var logOutput io.Writer = os.Stderr
  101. if logFilename != "" {
  102. f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
  103. if err != nil {
  104. log.Fatalf("can't open log file: %s", err)
  105. }
  106. defer f.Close()
  107. logOutput = f
  108. }
  109. if unsafeLogging {
  110. log.SetOutput(logOutput)
  111. } else {
  112. // We want to send the log output through our scrubber first
  113. log.SetOutput(&safelog.LogScrubber{Output: logOutput})
  114. }
  115. if !disableTLS && acmeHostnamesCommas == "" {
  116. log.Fatal("the --acme-hostnames option is required")
  117. }
  118. acmeHostnames := strings.Split(acmeHostnamesCommas, ",")
  119. log.Printf("starting")
  120. var err error
  121. ptInfo, err = pt.ServerSetup(nil)
  122. if err != nil {
  123. log.Fatalf("error in setup: %s", err)
  124. }
  125. go statsThread()
  126. var certManager *autocert.Manager
  127. if !disableTLS {
  128. log.Printf("ACME hostnames: %q", acmeHostnames)
  129. var cache autocert.Cache
  130. var cacheDir string
  131. cacheDir, err = getCertificateCacheDir()
  132. if err == nil {
  133. log.Printf("caching ACME certificates in directory %q", cacheDir)
  134. cache = autocert.DirCache(cacheDir)
  135. } else {
  136. log.Printf("disabling ACME certificate cache: %s", err)
  137. }
  138. certManager = &autocert.Manager{
  139. Prompt: autocert.AcceptTOS,
  140. HostPolicy: autocert.HostWhitelist(acmeHostnames...),
  141. Email: acmeEmail,
  142. Cache: cache,
  143. }
  144. }
  145. // The ACME HTTP-01 responder only works when it is running on port 80.
  146. // We actually open the port in the loop below, so that any errors can
  147. // be reported in the SMETHOD-ERROR of some bindaddr.
  148. // https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http-challenge
  149. needHTTP01Listener := !disableTLS
  150. listeners := make([]net.Listener, 0)
  151. for _, bindaddr := range ptInfo.Bindaddrs {
  152. if bindaddr.MethodName != ptMethodName {
  153. pt.SmethodError(bindaddr.MethodName, "no such method")
  154. continue
  155. }
  156. if needHTTP01Listener {
  157. addr := *bindaddr.Addr
  158. addr.Port = 80
  159. log.Printf("Starting HTTP-01 ACME listener")
  160. var lnHTTP01 *net.TCPListener
  161. lnHTTP01, err = net.ListenTCP("tcp", &addr)
  162. if err != nil {
  163. log.Printf("error opening HTTP-01 ACME listener: %s", err)
  164. pt.SmethodError(bindaddr.MethodName, "HTTP-01 ACME listener: "+err.Error())
  165. continue
  166. }
  167. server := &http.Server{
  168. Addr: addr.String(),
  169. Handler: certManager.HTTPHandler(nil),
  170. }
  171. go func() {
  172. log.Fatal(server.Serve(lnHTTP01))
  173. }()
  174. listeners = append(listeners, lnHTTP01)
  175. needHTTP01Listener = false
  176. }
  177. // We're not capable of listening on port 0 (i.e., an ephemeral port
  178. // unknown in advance). The reason is that while the net/http package
  179. // exposes ListenAndServe and ListenAndServeTLS, those functions never
  180. // return, so there's no opportunity to find out what the port number
  181. // is, in between the Listen and Serve steps.
  182. // https://groups.google.com/d/msg/Golang-nuts/3F1VRCCENp8/3hcayZiwYM8J
  183. if bindaddr.Addr.Port == 0 {
  184. err := fmt.Errorf(
  185. "cannot listen on port %d; configure a port using ServerTransportListenAddr",
  186. bindaddr.Addr.Port)
  187. log.Printf("error opening listener: %s", err)
  188. pt.SmethodError(bindaddr.MethodName, err.Error())
  189. continue
  190. }
  191. var transport *sf.Transport
  192. args := pt.Args{}
  193. if disableTLS {
  194. args.Add("tls", "no")
  195. transport = sf.NewSnowflakeServer(nil)
  196. } else {
  197. args.Add("tls", "yes")
  198. for _, hostname := range acmeHostnames {
  199. args.Add("hostname", hostname)
  200. }
  201. transport = sf.NewSnowflakeServer(certManager.GetCertificate)
  202. }
  203. ln, err := transport.Listen(bindaddr.Addr)
  204. if err != nil {
  205. log.Printf("error opening listener: %s", err)
  206. pt.SmethodError(bindaddr.MethodName, err.Error())
  207. continue
  208. }
  209. defer ln.Close()
  210. go acceptLoop(ln)
  211. pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args)
  212. listeners = append(listeners, ln)
  213. }
  214. pt.SmethodsDone()
  215. sigChan := make(chan os.Signal, 1)
  216. signal.Notify(sigChan, syscall.SIGTERM)
  217. if os.Getenv("TOR_PT_EXIT_ON_STDIN_CLOSE") == "1" {
  218. // This environment variable means we should treat EOF on stdin
  219. // just like SIGTERM: https://bugs.torproject.org/15435.
  220. go func() {
  221. if _, err := io.Copy(ioutil.Discard, os.Stdin); err != nil {
  222. log.Printf("error copying os.Stdin to ioutil.Discard: %v", err)
  223. }
  224. log.Printf("synthesizing SIGTERM because of stdin close")
  225. sigChan <- syscall.SIGTERM
  226. }()
  227. }
  228. // Wait for a signal.
  229. sig := <-sigChan
  230. // Signal received, shut down.
  231. log.Printf("caught signal %q, exiting", sig)
  232. for _, ln := range listeners {
  233. ln.Close()
  234. }
  235. }