connection_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package connection
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/url"
  7. "testing"
  8. "time"
  9. "github.com/gobwas/ws/wsutil"
  10. "github.com/rs/zerolog"
  11. "github.com/stretchr/testify/assert"
  12. )
  13. const (
  14. largeFileSize = 2 * 1024 * 1024
  15. )
  16. var (
  17. testConfig = &Config{
  18. OriginClient: &mockOriginClient{},
  19. GracePeriod: time.Millisecond * 100,
  20. }
  21. log = zerolog.Nop()
  22. testOriginURL = &url.URL{
  23. Scheme: "https",
  24. Host: "connectiontest.argotunnel.com",
  25. }
  26. testLargeResp = make([]byte, largeFileSize)
  27. )
  28. type testRequest struct {
  29. name string
  30. endpoint string
  31. expectedStatus int
  32. expectedBody []byte
  33. isProxyError bool
  34. }
  35. type mockOriginClient struct {
  36. }
  37. func (moc *mockOriginClient) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error {
  38. if isWebsocket {
  39. return wsEndpoint(w, r)
  40. }
  41. switch r.URL.Path {
  42. case "/ok":
  43. originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
  44. case "/large_file":
  45. originRespEndpoint(w, http.StatusOK, testLargeResp)
  46. case "/400":
  47. originRespEndpoint(w, http.StatusBadRequest, []byte(http.StatusText(http.StatusBadRequest)))
  48. case "/500":
  49. originRespEndpoint(w, http.StatusInternalServerError, []byte(http.StatusText(http.StatusInternalServerError)))
  50. case "/error":
  51. return fmt.Errorf("Failed to proxy to origin")
  52. default:
  53. originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
  54. }
  55. return nil
  56. }
  57. type nowriter struct {
  58. io.Reader
  59. }
  60. func (nowriter) Write(p []byte) (int, error) {
  61. return 0, fmt.Errorf("Writer not implemented")
  62. }
  63. func wsEndpoint(w ResponseWriter, r *http.Request) error {
  64. resp := &http.Response{
  65. StatusCode: http.StatusSwitchingProtocols,
  66. }
  67. _ = w.WriteRespHeaders(resp)
  68. clientReader := nowriter{r.Body}
  69. go func() {
  70. for {
  71. data, err := wsutil.ReadClientText(clientReader)
  72. if err != nil {
  73. return
  74. }
  75. if err := wsutil.WriteServerText(w, data); err != nil {
  76. return
  77. }
  78. }
  79. }()
  80. <-r.Context().Done()
  81. return nil
  82. }
  83. func originRespEndpoint(w ResponseWriter, status int, data []byte) {
  84. resp := &http.Response{
  85. StatusCode: status,
  86. }
  87. _ = w.WriteRespHeaders(resp)
  88. _, _ = w.Write(data)
  89. }
  90. type mockConnectedFuse struct{}
  91. func (mcf mockConnectedFuse) Connected() {}
  92. func (mcf mockConnectedFuse) IsConnected() bool {
  93. return true
  94. }
  95. func TestIsEventStream(t *testing.T) {
  96. tests := []struct {
  97. headers http.Header
  98. isEventStream bool
  99. }{
  100. {
  101. headers: newHeader("Content-Type", "text/event-stream"),
  102. isEventStream: true,
  103. },
  104. {
  105. headers: newHeader("content-type", "text/event-stream"),
  106. isEventStream: true,
  107. },
  108. {
  109. headers: newHeader("Content-Type", "text/event-stream; charset=utf-8"),
  110. isEventStream: true,
  111. },
  112. {
  113. headers: newHeader("Content-Type", "application/json"),
  114. isEventStream: false,
  115. },
  116. {
  117. headers: http.Header{},
  118. isEventStream: false,
  119. },
  120. }
  121. for _, test := range tests {
  122. assert.Equal(t, test.isEventStream, IsServerSentEvent(test.headers))
  123. }
  124. }
  125. func newHeader(key, value string) http.Header {
  126. header := http.Header{}
  127. header.Add(key, value)
  128. return header
  129. }