123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231 |
- package main
- import (
- "context"
- "log"
- "strconv"
- "strings"
- "time"
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/sqs"
- "github.com/aws/aws-sdk-go-v2/service/sqs/types"
- "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
- "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
- "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util"
- )
// cleanupThreshold is the (negative) offset added to the current time to
// obtain the cleanup cutoff: a client queue whose last-modified timestamp is
// before that cutoff, i.e. more than two minutes old, is deleted.
const cleanupThreshold = -2 * time.Minute
- type sqsHandler struct {
- SQSClient sqsclient.SQSClient
- SQSQueueURL *string
- IPC *IPC
- cleanupInterval time.Duration
- }
- func (r *sqsHandler) pollMessages(ctx context.Context, chn chan<- *types.Message) {
- for {
- select {
- case <-ctx.Done():
- // if context is cancelled
- return
- default:
- res, err := r.SQSClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
- QueueUrl: r.SQSQueueURL,
- MaxNumberOfMessages: 10,
- WaitTimeSeconds: 15,
- MessageAttributeNames: []string{
- string(types.QueueAttributeNameAll),
- },
- })
- if err != nil {
- log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err)
- continue
- }
- for _, message := range res.Messages {
- chn <- &message
- }
- }
- }
- }
- func (r *sqsHandler) cleanupClientQueues(ctx context.Context) {
- for range time.NewTicker(r.cleanupInterval).C {
- // Runs at fixed intervals to clean up any client queues that were last changed more than 2 minutes ago
- select {
- case <-ctx.Done():
- // if context is cancelled
- return
- default:
- queueURLsList := []string{}
- var nextToken *string
- for {
- res, err := r.SQSClient.ListQueues(ctx, &sqs.ListQueuesInput{
- QueueNamePrefix: aws.String("snowflake-client-"),
- MaxResults: aws.Int32(1000),
- NextToken: nextToken,
- })
- if err != nil {
- log.Printf("SQSHandler: encountered error while retrieving client queues to clean up: %v\n", err)
- // client queues will be cleaned up the next time the cleanup operation is triggered automatically
- break
- }
- queueURLsList = append(queueURLsList, res.QueueUrls...)
- if res.NextToken == nil {
- break
- } else {
- nextToken = res.NextToken
- }
- }
- numDeleted := 0
- cleanupCutoff := time.Now().Add(cleanupThreshold)
- for _, queueURL := range queueURLsList {
- if !strings.Contains(queueURL, "snowflake-client-") {
- continue
- }
- res, err := r.SQSClient.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{
- QueueUrl: aws.String(queueURL),
- AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
- })
- if err != nil {
- // According to the AWS SQS docs, the deletion process for a queue can take up to 60 seconds. So the queue
- // can be in the process of being deleted, but will still be returned by the ListQueues operation, but
- // fail when we try to GetQueueAttributes for the queue
- log.Printf("SQSHandler: encountered error while getting attribute of client queue %s. queue may already be deleted.\n", queueURL)
- continue
- }
- lastModifiedInt64, err := strconv.ParseInt(res.Attributes[string(types.QueueAttributeNameLastModifiedTimestamp)], 10, 64)
- if err != nil {
- log.Printf("SQSHandler: encountered invalid lastModifiedTimetamp value from client queue %s: %v\n", queueURL, err)
- continue
- }
- lastModified := time.Unix(lastModifiedInt64, 0)
- if lastModified.Before(cleanupCutoff) {
- _, err := r.SQSClient.DeleteQueue(ctx, &sqs.DeleteQueueInput{
- QueueUrl: aws.String(queueURL),
- })
- if err != nil {
- log.Printf("SQSHandler: encountered error when deleting client queue %s: %v\n", queueURL, err)
- continue
- } else {
- numDeleted += 1
- }
- }
- }
- log.Printf("SQSHandler: finished running iteration of client queue cleanup. found and deleted %d client queues.\n", numDeleted)
- }
- }
- }
- func (r *sqsHandler) handleMessage(context context.Context, message *types.Message) {
- var encPollReq []byte
- var response []byte
- var err error
- clientID := message.MessageAttributes["ClientID"].StringValue
- if clientID == nil {
- log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
- return
- }
- res, err := r.SQSClient.CreateQueue(context, &sqs.CreateQueueInput{
- QueueName: aws.String("snowflake-client-" + *clientID),
- })
- if err != nil {
- log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err)
- return
- }
- answerSQSURL := res.QueueUrl
- encPollReq = []byte(*message.Body)
- // Get best guess Client IP for geolocating
- remoteAddr := ""
- req, err := messages.DecodeClientPollRequest(encPollReq)
- if err != nil {
- log.Printf("SQSHandler: error encounted when decoding client poll request %s: %v\n", *clientID, err)
- } else {
- sdp, err := util.DeserializeSessionDescription(req.Offer)
- if err != nil {
- log.Printf("SQSHandler: error encounted when deserializing session desc %s: %v\n", *clientID, err)
- } else {
- candidateAddrs := util.GetCandidateAddrs(sdp.SDP)
- if len(candidateAddrs) > 0 {
- remoteAddr = candidateAddrs[0].String()
- }
- }
- }
- arg := messages.Arg{
- Body: encPollReq,
- RemoteAddr: remoteAddr,
- RendezvousMethod: messages.RendezvousSqs,
- }
- err = r.IPC.ClientOffers(arg, &response)
- if err != nil {
- log.Printf("SQSHandler: error encountered when handling message: %v\n", err)
- return
- }
- r.SQSClient.SendMessage(context, &sqs.SendMessageInput{
- QueueUrl: answerSQSURL,
- MessageBody: aws.String(string(response)),
- })
- }
- func (r *sqsHandler) deleteMessage(context context.Context, message *types.Message) {
- r.SQSClient.DeleteMessage(context, &sqs.DeleteMessageInput{
- QueueUrl: r.SQSQueueURL,
- ReceiptHandle: message.ReceiptHandle,
- })
- }
- func newSQSHandler(context context.Context, client sqsclient.SQSClient, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) {
- // Creates the queue if a queue with the same name doesn't exist. If a queue with the same name and attributes
- // already exists, then nothing will happen. If a queue with the same name, but different attributes exists, then
- // an error will be returned
- res, err := client.CreateQueue(context, &sqs.CreateQueueInput{
- QueueName: aws.String(sqsQueueName),
- Attributes: map[string]string{
- "MessageRetentionPeriod": strconv.FormatInt(int64((5 * time.Minute).Seconds()), 10),
- },
- })
- if err != nil {
- return nil, err
- }
- return &sqsHandler{
- SQSClient: client,
- SQSQueueURL: res.QueueUrl,
- IPC: i,
- cleanupInterval: time.Second * 30,
- }, nil
- }
- func (r *sqsHandler) PollAndHandleMessages(ctx context.Context) {
- log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL)
- messagesChn := make(chan *types.Message, 2)
- go r.pollMessages(ctx, messagesChn)
- go r.cleanupClientQueues(ctx)
- for message := range messagesChn {
- select {
- case <-ctx.Done():
- // if context is cancelled
- return
- default:
- r.handleMessage(ctx, message)
- r.deleteMessage(ctx, message)
- }
- }
- }
|