toot-relay.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. package main
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/x509"
  6. "encoding/base64"
  7. "encoding/binary"
  8. "flag"
  9. "fmt"
  10. "net/http"
  11. "os"
  12. "strconv"
  13. "strings"
  14. "time"
  15. "github.com/sideshow/apns2"
  16. "github.com/sideshow/apns2/certificate"
  17. "github.com/sideshow/apns2/payload"
  18. log "github.com/sirupsen/logrus"
  19. "golang.org/x/net/http2"
  20. httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http"
  21. dd_logrus "gopkg.in/DataDog/dd-trace-go.v1/contrib/sirupsen/logrus"
  22. "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
  23. )
  24. type Message struct {
  25. isProduction bool
  26. notification *apns2.Notification
  27. requestLog *log.Entry // For logging with datadog context
  28. }
  29. var (
  30. developmentClient *apns2.Client
  31. productionClient *apns2.Client
  32. topic string
  33. messageChan chan *Message
  34. maxQueueSize int
  35. maxWorkers int
  36. ctx context.Context
  37. )
  38. func worker(workerId int) {
  39. log.Info(fmt.Sprintf("starting worker %d", workerId))
  40. defer log.Info(fmt.Sprintf("stopping worker %d", workerId))
  41. var client *apns2.Client
  42. for msg := range messageChan {
  43. if msg.isProduction {
  44. client = productionClient
  45. } else {
  46. client = developmentClient
  47. }
  48. res, err := client.Push(msg.notification)
  49. if err != nil {
  50. msg.requestLog.Error(fmt.Sprintf("Push error: %s", err))
  51. continue
  52. }
  53. if res.Sent() {
  54. msg.requestLog.WithFields(log.Fields{
  55. "status-code": res.StatusCode,
  56. "apns-id": res.ApnsID,
  57. "reason": res.Reason,
  58. "device-token": msg.notification.DeviceToken,
  59. "expiration": msg.notification.Expiration,
  60. "priority": msg.notification.Priority,
  61. "collapse-id": msg.notification.CollapseID,
  62. }).Info(fmt.Sprintf("Sent notification (%v)", res.StatusCode))
  63. } else {
  64. msg.requestLog.WithFields(log.Fields{
  65. "status-code": res.StatusCode,
  66. "apns-id": res.ApnsID,
  67. "reason": res.Reason,
  68. }).Error(fmt.Sprintf("Failed to send notification (%v)", res.StatusCode))
  69. }
  70. }
  71. }
  72. func main() {
  73. tracer.Start()
  74. defer tracer.Stop()
  75. mux := httptrace.NewServeMux()
  76. log.AddHook(&dd_logrus.DDContextLogHook{})
  77. ctx = context.Background()
  78. flag.IntVar(&maxQueueSize, "max-queue-size", 4096, "Maximum number of messages to queue")
  79. flag.IntVar(&maxWorkers, "max-workers", 8, "Maximum number of workers")
  80. flag.Parse()
  81. topic = env("TOPIC", "cx.c3.toot")
  82. p12file := env("P12_FILENAME", "toot-relay.p12")
  83. p12base64 := env("P12_BASE64", "")
  84. p12password := env("P12_PASSWORD", "")
  85. port := env("PORT", "42069")
  86. tlsCrtFile := env("CRT_FILENAME", "toot-relay.crt")
  87. tlsKeyFile := env("KEY_FILENAME", "toot-relay.key")
  88. // CA_FILENAME can be set to a file that contains PEM encoded certificates that will be
  89. // used as the sole root CAs when connecting to the Apple Notification Service API.
  90. // If unset, the system-wide certificate store will be used.
  91. caFile := env("CA_FILENAME", "")
  92. var rootCAs *x509.CertPool
  93. if caPEM, err := os.ReadFile(caFile); err == nil {
  94. rootCAs = x509.NewCertPool()
  95. if ok := rootCAs.AppendCertsFromPEM(caPEM); !ok {
  96. log.Fatal(fmt.Sprintf("CA file %s specified but no CA certificates could be loaded\n", caFile))
  97. }
  98. }
  99. if p12base64 != "" {
  100. bytes, err := base64.StdEncoding.DecodeString(p12base64)
  101. if err != nil {
  102. log.Fatal(fmt.Sprintf("Base64 decoding error: %s", err))
  103. }
  104. cert, err := certificate.FromP12Bytes(bytes, p12password)
  105. if err != nil {
  106. log.Fatal(fmt.Sprintf("Error parsing certificate: %s", err))
  107. }
  108. developmentClient = apns2.NewClient(cert).Development()
  109. productionClient = apns2.NewClient(cert).Production()
  110. } else {
  111. cert, err := certificate.FromP12File(p12file, p12password)
  112. if err != nil {
  113. log.Fatal(fmt.Sprintf("Error loading certificate file: %s", err))
  114. }
  115. developmentClient = apns2.NewClient(cert).Development()
  116. productionClient = apns2.NewClient(cert).Production()
  117. }
  118. if rootCAs != nil {
  119. developmentClient.HTTPClient.Transport.(*http2.Transport).TLSClientConfig.RootCAs = rootCAs
  120. productionClient.HTTPClient.Transport.(*http2.Transport).TLSClientConfig.RootCAs = rootCAs
  121. }
  122. mux.HandleFunc("/relay-to/", handler)
  123. messageChan = make(chan *Message, maxQueueSize)
  124. for i := 1; i <= maxWorkers; i++ {
  125. go worker(i)
  126. }
  127. if _, err := os.Stat(tlsCrtFile); !os.IsNotExist(err) {
  128. log.Fatal(http.ListenAndServeTLS(":"+port, tlsCrtFile, tlsKeyFile, mux))
  129. } else {
  130. log.Fatal(http.ListenAndServe(":"+port, mux))
  131. }
  132. }
  133. func handler(writer http.ResponseWriter, request *http.Request) {
  134. span, sctx := tracer.StartSpanFromContext(ctx, "web.request", tracer.ResourceName(request.RequestURI))
  135. defer span.Finish()
  136. requestLog := log.WithContext(sctx)
  137. components := strings.Split(request.URL.Path, "/")
  138. if len(components) < 4 {
  139. writer.WriteHeader(500)
  140. fmt.Fprintln(writer, "Invalid URL path:", request.URL.Path)
  141. requestLog.Error(fmt.Sprintf("Invalid URL path: %s", request.URL.Path))
  142. return
  143. }
  144. isProduction := components[2] == "production"
  145. notification := &apns2.Notification{}
  146. notification.DeviceToken = components[3]
  147. buffer := new(bytes.Buffer)
  148. buffer.ReadFrom(request.Body)
  149. encodedString := encode85(buffer.Bytes())
  150. payload := payload.NewPayload().Alert("🎺").MutableContent().ContentAvailable().Custom("p", encodedString)
  151. if len(components) > 4 {
  152. payload.Custom("x", strings.Join(components[4:], "/"))
  153. }
  154. notification.Payload = payload
  155. notification.Topic = topic
  156. switch request.Header.Get("Content-Encoding") {
  157. case "aesgcm":
  158. if publicKey, err := encodedValue(request.Header, "Crypto-Key", "dh"); err == nil {
  159. payload.Custom("k", publicKey)
  160. } else {
  161. writer.WriteHeader(500)
  162. fmt.Fprintln(writer, "Error retrieving public key:", err)
  163. requestLog.Error(fmt.Sprintf("Error retrieving public key: %s", err))
  164. return
  165. }
  166. if salt, err := encodedValue(request.Header, "Encryption", "salt"); err == nil {
  167. payload.Custom("s", salt)
  168. } else {
  169. writer.WriteHeader(500)
  170. fmt.Fprintln(writer, "Error retrieving salt:", err)
  171. requestLog.Error(fmt.Sprintf("Error retrieving salt: %s", err))
  172. return
  173. }
  174. //case "aes128gcm": // No further headers needed. However, not implemented on client side so return 415.
  175. default:
  176. writer.WriteHeader(415)
  177. fmt.Fprintln(writer, "Unsupported Content-Encoding:", request.Header.Get("Content-Encoding"))
  178. requestLog.Error(fmt.Sprintf("Unsupported Content-Encoding: %s", request.Header.Get("Content-Encoding")))
  179. return
  180. }
  181. if seconds := request.Header.Get("TTL"); seconds != "" {
  182. if ttl, err := strconv.Atoi(seconds); err == nil {
  183. notification.Expiration = time.Now().Add(time.Duration(ttl) * time.Second)
  184. }
  185. }
  186. if topic := request.Header.Get("Topic"); topic != "" {
  187. notification.CollapseID = topic
  188. }
  189. switch request.Header.Get("Urgency") {
  190. case "very-low", "low":
  191. notification.Priority = apns2.PriorityLow
  192. default:
  193. notification.Priority = apns2.PriorityHigh
  194. }
  195. messageChan <- &Message{isProduction, notification, requestLog}
  196. // always reply w/ success, since we don't know how apple responded
  197. writer.WriteHeader(201)
  198. }
  199. func env(name, defaultValue string) string {
  200. if value, isPresent := os.LookupEnv(name); isPresent {
  201. return value
  202. } else {
  203. return defaultValue
  204. }
  205. }
  206. func encodedValue(header http.Header, name, key string) (string, error) {
  207. keyValues := parseKeyValues(header.Get(name))
  208. value, exists := keyValues[key]
  209. if !exists {
  210. return "", fmt.Errorf("value %s not found in header %s", key, name)
  211. }
  212. bytes, err := base64.RawURLEncoding.DecodeString(value)
  213. if err != nil {
  214. return "", err
  215. }
  216. return encode85(bytes), nil
  217. }
  // parseKeyValues parses header parameters such as "dh=BASE64;p256ecdsa=BASE64"
  // into a key -> value map.
  //
  // Parsing rules:
  //   - Entries are separated by ';' or ','. A header carrying several
  //     comma-separated parameter sets is flattened into one map, and a later
  //     duplicate key wins.
  //   - Surrounding whitespace is trimmed from keys and values, and surrounding
  //     double quotes from values.
  //   - Trailing '=' padding is stripped from values, so padded base64url still
  //     decodes with base64.RawURLEncoding.
  //   - Entries without '=' or with an empty key are skipped. These values come
  //     from untrusted request headers and previously caused an index panic.
  func parseKeyValues(values string) map[string]string {
  	entries := strings.FieldsFunc(values, func(c rune) bool {
  		return c == ';' || c == ','
  	})
  	m := make(map[string]string, len(entries))
  	for _, entry := range entries {
  		eq := strings.IndexByte(entry, '=')
  		if eq < 0 {
  			continue // malformed entry: skip it instead of indexing past the split
  		}
  		key := strings.TrimSpace(entry[:eq])
  		if key == "" {
  			continue
  		}
  		value := strings.TrimSpace(entry[eq+1:])
  		value = strings.Trim(value, `"`)
  		value = strings.TrimRight(value, "=")
  		m[key] = value
  	}
  	return m
  }
  // z85digits is the 85-character alphabet used by encode85: the digit with
  // value i is z85digits[i]. It is the alphabet of ZeroMQ's Z85 encoding.
  var z85digits = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.-:+=^!/*?&<>()[]{}@%$#")

  // encode85 encodes data with the z85digits alphabet.
  //
  // Each complete 4-byte group is read as a big-endian uint32 and written as
  // five base-85 digits, most significant first, exactly as in Z85. Z85 itself
  // only accepts lengths that are multiples of 4. For the 1-3 leftover bytes
  // this function instead reads them as one big-endian integer and writes it
  // as leftover+1 digits. So the output length is 5*(len(data)/4), plus
  // len(data)%4+1 when that remainder is non-zero. Any decoder of this output
  // must apply the same tail rule.
  func encode85(data []byte) string {
  	whole := len(data) &^ 3 // bytes covered by complete 4-byte groups
  	tail := len(data) - whole
  	size := whole / 4 * 5
  	if tail > 0 {
  		size += tail + 1
  	}
  	out := make([]byte, size)

  	o := 0
  	for i := 0; i < whole; i += 4 {
  		v := binary.BigEndian.Uint32(data[i : i+4])
  		for d := 4; d >= 0; d-- {
  			out[o+d] = z85digits[v%85]
  			v /= 85
  		}
  		o += 5
  	}

  	if tail > 0 {
  		var v uint32 // at most 3 bytes, so it always fits
  		for _, b := range data[whole:] {
  			v = v<<8 | uint32(b)
  		}
  		for d := tail; d >= 0; d-- {
  			out[o+d] = z85digits[v%85]
  			v /= 85
  		}
  	}
  	return string(out)
  }