server.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. // Snowflake-specific websocket server plugin. It reports the transport name as
  2. // "snowflake".
  3. package main
  4. import (
  5. "crypto/tls"
  6. "errors"
  7. "flag"
  8. "fmt"
  9. "io"
  10. "io/ioutil"
  11. "log"
  12. "net"
  13. "net/http"
  14. "os"
  15. "os/signal"
  16. "path/filepath"
  17. "strings"
  18. "sync"
  19. "syscall"
  20. "time"
  21. "git.torproject.org/pluggable-transports/goptlib.git"
  22. "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
  23. "git.torproject.org/pluggable-transports/websocket.git/websocket"
  24. "golang.org/x/crypto/acme/autocert"
  25. "golang.org/x/net/http2"
  26. )
  27. const ptMethodName = "snowflake"
  28. const requestTimeout = 10 * time.Second
  29. const maxMessageSize = 64 * 1024
  30. // How long to wait for ListenAndServe or ListenAndServeTLS to return an error
  31. // before deciding that it's not going to return.
  32. const listenAndServeErrorTimeout = 100 * time.Millisecond
  33. var ptInfo pt.ServerInfo
  34. // When a connection handler starts, +1 is written to this channel; when it
  35. // ends, -1 is written.
  36. var handlerChan = make(chan int)
  37. func usage() {
  38. fmt.Fprintf(os.Stderr, `Usage: %s [OPTIONS]
  39. WebSocket server pluggable transport for Snowflake. Works only as a managed
  40. proxy. Uses TLS with ACME (Let's Encrypt) by default. Set the certificate
  41. hostnames with the --acme-hostnames option. Use ServerTransportListenAddr in
  42. torrc to choose the listening port. When using TLS, this program will open an
  43. additional HTTP listener on port 80 to work with ACME.
  44. `, os.Args[0])
  45. flag.PrintDefaults()
  46. }
// webSocketConn is an abstraction that makes an underlying WebSocket
// connection look like an io.ReadWriteCloser.
type webSocketConn struct {
	// Ws is the underlying WebSocket that messages are read from and
	// written to.
	Ws *websocket.WebSocket
	// messageBuf holds the part of the most recently received message
	// payload that Read has not yet handed to its caller.
	messageBuf []byte
}
  53. // Implements io.Reader.
  54. func (conn *webSocketConn) Read(b []byte) (n int, err error) {
  55. for len(conn.messageBuf) == 0 {
  56. var m websocket.Message
  57. m, err = conn.Ws.ReadMessage()
  58. if err != nil {
  59. return
  60. }
  61. if m.Opcode == 8 {
  62. err = io.EOF
  63. return
  64. }
  65. if m.Opcode != 2 {
  66. err = errors.New(fmt.Sprintf("got non-binary opcode %d", m.Opcode))
  67. return
  68. }
  69. conn.messageBuf = m.Payload
  70. }
  71. n = copy(b, conn.messageBuf)
  72. conn.messageBuf = conn.messageBuf[n:]
  73. return
  74. }
  75. // Implements io.Writer.
  76. func (conn *webSocketConn) Write(b []byte) (int, error) {
  77. err := conn.Ws.WriteMessage(2, b)
  78. return len(b), err
  79. }
  80. // Implements io.Closer.
  81. func (conn *webSocketConn) Close() error {
  82. // Ignore any error in trying to write a Close frame.
  83. _ = conn.Ws.WriteFrame(8, nil)
  84. return conn.Ws.Conn.Close()
  85. }
  86. // Create a new webSocketConn.
  87. func newWebSocketConn(ws *websocket.WebSocket) webSocketConn {
  88. var conn webSocketConn
  89. conn.Ws = ws
  90. return conn
  91. }
  92. // Copy from WebSocket to socket and vice versa.
  93. func proxy(local *net.TCPConn, conn *webSocketConn) {
  94. var wg sync.WaitGroup
  95. wg.Add(2)
  96. go func() {
  97. _, err := io.Copy(conn, local)
  98. if err != nil {
  99. log.Printf("error copying ORPort to WebSocket")
  100. }
  101. local.CloseRead()
  102. conn.Close()
  103. wg.Done()
  104. }()
  105. go func() {
  106. _, err := io.Copy(local, conn)
  107. if err != nil {
  108. log.Printf("error copying WebSocket to ORPort")
  109. }
  110. local.CloseWrite()
  111. conn.Close()
  112. wg.Done()
  113. }()
  114. wg.Wait()
  115. }
  116. // Return an address string suitable to pass into pt.DialOr.
  117. func clientAddr(clientIPParam string) string {
  118. if clientIPParam == "" {
  119. return ""
  120. }
  121. // Check if client addr is a valid IP
  122. clientIP := net.ParseIP(clientIPParam)
  123. if clientIP == nil {
  124. return ""
  125. }
  126. // Add a dummy port number. USERADDR requires a port number.
  127. return (&net.TCPAddr{IP: clientIP, Port: 1, Zone: ""}).String()
  128. }
  129. func webSocketHandler(ws *websocket.WebSocket) {
  130. // Undo timeouts on HTTP request handling.
  131. ws.Conn.SetDeadline(time.Time{})
  132. conn := newWebSocketConn(ws)
  133. defer conn.Close()
  134. handlerChan <- 1
  135. defer func() {
  136. handlerChan <- -1
  137. }()
  138. // Pass the address of client as the remote address of incoming connection
  139. clientIPParam := ws.Request().URL.Query().Get("client_ip")
  140. addr := clientAddr(clientIPParam)
  141. if addr == "" {
  142. statsChannel <- false
  143. } else {
  144. statsChannel <- true
  145. }
  146. or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
  147. if err != nil {
  148. log.Printf("failed to connect to ORPort: %s", err)
  149. return
  150. }
  151. defer or.Close()
  152. proxy(or, &conn)
  153. }
  154. func initServer(addr *net.TCPAddr,
  155. getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error),
  156. listenAndServe func(*http.Server, chan<- error)) (*http.Server, error) {
  157. // We're not capable of listening on port 0 (i.e., an ephemeral port
  158. // unknown in advance). The reason is that while the net/http package
  159. // exposes ListenAndServe and ListenAndServeTLS, those functions never
  160. // return, so there's no opportunity to find out what the port number
  161. // is, in between the Listen and Serve steps.
  162. // https://groups.google.com/d/msg/Golang-nuts/3F1VRCCENp8/3hcayZiwYM8J
  163. if addr.Port == 0 {
  164. return nil, fmt.Errorf("cannot listen on port %d; configure a port using ServerTransportListenAddr", addr.Port)
  165. }
  166. var config websocket.Config
  167. config.MaxMessageSize = maxMessageSize
  168. server := &http.Server{
  169. Addr: addr.String(),
  170. Handler: config.Handler(webSocketHandler),
  171. ReadTimeout: requestTimeout,
  172. }
  173. // We need to override server.TLSConfig.GetCertificate--but first
  174. // server.TLSConfig needs to be non-nil. If we just create our own new
  175. // &tls.Config, it will lack the default settings that the net/http
  176. // package sets up for things like HTTP/2. Therefore we first call
  177. // http2.ConfigureServer for its side effect of initializing
  178. // server.TLSConfig properly. An alternative would be to make a dummy
  179. // net.Listener, call Serve on it, and let it return.
  180. // https://github.com/golang/go/issues/16588#issuecomment-237386446
  181. err := http2.ConfigureServer(server, nil)
  182. if err != nil {
  183. return server, err
  184. }
  185. server.TLSConfig.GetCertificate = getCertificate
  186. // Another unfortunate effect of the inseparable net/http ListenAndServe
  187. // is that we can't check for Listen errors like "permission denied" and
  188. // "address already in use" without potentially entering the infinite
  189. // loop of Serve. The hack we apply here is to wait a short time,
  190. // listenAndServeErrorTimeout, to see if an error is returned (because
  191. // it's better if the error message goes to the tor log through
  192. // SMETHOD-ERROR than if it only goes to the snowflake log).
  193. errChan := make(chan error)
  194. go listenAndServe(server, errChan)
  195. select {
  196. case err = <-errChan:
  197. break
  198. case <-time.After(listenAndServeErrorTimeout):
  199. break
  200. }
  201. return server, err
  202. }
  203. func startServer(addr *net.TCPAddr) (*http.Server, error) {
  204. return initServer(addr, nil, func(server *http.Server, errChan chan<- error) {
  205. log.Printf("listening with plain HTTP on %s", addr)
  206. err := server.ListenAndServe()
  207. if err != nil {
  208. log.Printf("error in ListenAndServe: %s", err)
  209. }
  210. errChan <- err
  211. })
  212. }
  213. func startServerTLS(addr *net.TCPAddr, getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)) (*http.Server, error) {
  214. return initServer(addr, getCertificate, func(server *http.Server, errChan chan<- error) {
  215. log.Printf("listening with HTTPS on %s", addr)
  216. err := server.ListenAndServeTLS("", "")
  217. if err != nil {
  218. log.Printf("error in ListenAndServeTLS: %s", err)
  219. }
  220. errChan <- err
  221. })
  222. }
  223. func getCertificateCacheDir() (string, error) {
  224. stateDir, err := pt.MakeStateDir()
  225. if err != nil {
  226. return "", err
  227. }
  228. return filepath.Join(stateDir, "snowflake-certificate-cache"), nil
  229. }
  230. func main() {
  231. var acmeEmail string
  232. var acmeHostnamesCommas string
  233. var disableTLS bool
  234. var logFilename string
  235. flag.Usage = usage
  236. flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications")
  237. flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate")
  238. flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
  239. flag.StringVar(&logFilename, "log", "", "log file to write to")
  240. flag.Parse()
  241. log.SetFlags(log.LstdFlags | log.LUTC)
  242. var logOutput io.Writer = os.Stderr
  243. if logFilename != "" {
  244. f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
  245. if err != nil {
  246. log.Fatalf("can't open log file: %s", err)
  247. }
  248. defer f.Close()
  249. logOutput = f
  250. }
  251. //We want to send the log output through our scrubber first
  252. log.SetOutput(&safelog.LogScrubber{Output: logOutput})
  253. if !disableTLS && acmeHostnamesCommas == "" {
  254. log.Fatal("the --acme-hostnames option is required")
  255. }
  256. acmeHostnames := strings.Split(acmeHostnamesCommas, ",")
  257. log.Printf("starting")
  258. var err error
  259. ptInfo, err = pt.ServerSetup(nil)
  260. if err != nil {
  261. log.Fatalf("error in setup: %s", err)
  262. }
  263. go statsThread()
  264. var certManager *autocert.Manager
  265. if !disableTLS {
  266. log.Printf("ACME hostnames: %q", acmeHostnames)
  267. var cache autocert.Cache
  268. cacheDir, err := getCertificateCacheDir()
  269. if err == nil {
  270. log.Printf("caching ACME certificates in directory %q", cacheDir)
  271. cache = autocert.DirCache(cacheDir)
  272. } else {
  273. log.Printf("disabling ACME certificate cache: %s", err)
  274. }
  275. certManager = &autocert.Manager{
  276. Prompt: autocert.AcceptTOS,
  277. HostPolicy: autocert.HostWhitelist(acmeHostnames...),
  278. Email: acmeEmail,
  279. Cache: cache,
  280. }
  281. }
  282. // The ACME HTTP-01 responder only works when it is running on port 80.
  283. // We actually open the port in the loop below, so that any errors can
  284. // be reported in the SMETHOD-ERROR of some bindaddr.
  285. // https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http-challenge
  286. needHTTP01Listener := !disableTLS
  287. servers := make([]*http.Server, 0)
  288. for _, bindaddr := range ptInfo.Bindaddrs {
  289. if bindaddr.MethodName != ptMethodName {
  290. pt.SmethodError(bindaddr.MethodName, "no such method")
  291. continue
  292. }
  293. if needHTTP01Listener {
  294. addr := *bindaddr.Addr
  295. addr.Port = 80
  296. log.Printf("Starting HTTP-01 ACME listener")
  297. lnHTTP01, err := net.ListenTCP("tcp", &addr)
  298. if err != nil {
  299. log.Printf("error opening HTTP-01 ACME listener: %s", err)
  300. pt.SmethodError(bindaddr.MethodName, "HTTP-01 ACME listener: "+err.Error())
  301. continue
  302. }
  303. server := &http.Server{
  304. Addr: addr.String(),
  305. Handler: certManager.HTTPHandler(nil),
  306. }
  307. go func() {
  308. log.Fatal(server.Serve(lnHTTP01))
  309. }()
  310. servers = append(servers, server)
  311. needHTTP01Listener = false
  312. }
  313. var server *http.Server
  314. args := pt.Args{}
  315. if disableTLS {
  316. args.Add("tls", "no")
  317. server, err = startServer(bindaddr.Addr)
  318. } else {
  319. args.Add("tls", "yes")
  320. for _, hostname := range acmeHostnames {
  321. args.Add("hostname", hostname)
  322. }
  323. server, err = startServerTLS(bindaddr.Addr, certManager.GetCertificate)
  324. }
  325. if err != nil {
  326. log.Printf("error opening listener: %s", err)
  327. pt.SmethodError(bindaddr.MethodName, err.Error())
  328. continue
  329. }
  330. pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args)
  331. servers = append(servers, server)
  332. }
  333. pt.SmethodsDone()
  334. var numHandlers int = 0
  335. var sig os.Signal
  336. sigChan := make(chan os.Signal, 1)
  337. signal.Notify(sigChan, syscall.SIGTERM)
  338. if os.Getenv("TOR_PT_EXIT_ON_STDIN_CLOSE") == "1" {
  339. // This environment variable means we should treat EOF on stdin
  340. // just like SIGTERM: https://bugs.torproject.org/15435.
  341. go func() {
  342. io.Copy(ioutil.Discard, os.Stdin)
  343. log.Printf("synthesizing SIGTERM because of stdin close")
  344. sigChan <- syscall.SIGTERM
  345. }()
  346. }
  347. // keep track of handlers and wait for a signal
  348. sig = nil
  349. for sig == nil {
  350. select {
  351. case n := <-handlerChan:
  352. numHandlers += n
  353. case sig = <-sigChan:
  354. }
  355. }
  356. // signal received, shut down
  357. log.Printf("caught signal %q, exiting", sig)
  358. for _, server := range servers {
  359. server.Close()
  360. }
  361. for numHandlers > 0 {
  362. numHandlers += <-handlerChan
  363. }
  364. }