/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

// Original author: ekr@rtfm.com
#include <atomic>
#include <cerrno>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <string>

#include "sigslot.h"

#include "logging.h"
#include "nsNetCID.h"
#include "nsITimer.h"
#include "nsComponentManagerUtils.h"
#include "nsThreadUtils.h"
#include "nsXPCOM.h"

#include "transportflow.h"
#include "transportlayer.h"
#include "transportlayerloopback.h"

#include "runnable_utils.h"
#include "usrsctp.h"

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
#include "gtest_utils.h"

using namespace mozilla;
  25. static bool sctp_logging = false;
  26. static int port_number = 5000;
  27. namespace {
  28. class TransportTestPeer;
  29. class SendPeriodic : public nsITimerCallback {
  30. public:
  31. SendPeriodic(TransportTestPeer *peer, int to_send) :
  32. peer_(peer),
  33. to_send_(to_send) {}
  34. NS_DECL_THREADSAFE_ISUPPORTS
  35. NS_DECL_NSITIMERCALLBACK
  36. protected:
  37. virtual ~SendPeriodic() {}
  38. TransportTestPeer *peer_;
  39. int to_send_;
  40. };
  41. NS_IMPL_ISUPPORTS(SendPeriodic, nsITimerCallback)
  42. class TransportTestPeer : public sigslot::has_slots<> {
  43. public:
  44. TransportTestPeer(std::string name, int local_port, int remote_port,
  45. MtransportTestUtils* utils)
  46. : name_(name), connected_(false),
  47. sent_(0), received_(0),
  48. flow_(new TransportFlow()),
  49. loopback_(new TransportLayerLoopback()),
  50. sctp_(usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, receive_cb, nullptr, 0, nullptr)),
  51. timer_(do_CreateInstance(NS_TIMER_CONTRACTID)),
  52. periodic_(nullptr),
  53. test_utils_(utils) {
  54. std::cerr << "Creating TransportTestPeer; flow=" <<
  55. static_cast<void *>(flow_.get()) <<
  56. " local=" << local_port <<
  57. " remote=" << remote_port << std::endl;
  58. usrsctp_register_address(static_cast<void *>(this));
  59. int r = usrsctp_set_non_blocking(sctp_, 1);
  60. EXPECT_GE(r, 0);
  61. struct linger l;
  62. l.l_onoff = 1;
  63. l.l_linger = 0;
  64. r = usrsctp_setsockopt(sctp_, SOL_SOCKET, SO_LINGER, &l,
  65. (socklen_t)sizeof(l));
  66. EXPECT_GE(r, 0);
  67. struct sctp_event subscription;
  68. memset(&subscription, 0, sizeof(subscription));
  69. subscription.se_assoc_id = SCTP_ALL_ASSOC;
  70. subscription.se_on = 1;
  71. subscription.se_type = SCTP_ASSOC_CHANGE;
  72. r = usrsctp_setsockopt(sctp_, IPPROTO_SCTP, SCTP_EVENT, &subscription,
  73. sizeof(subscription));
  74. EXPECT_GE(r, 0);
  75. memset(&local_addr_, 0, sizeof(local_addr_));
  76. local_addr_.sconn_family = AF_CONN;
  77. #if !defined(__Userspace_os_Linux) && !defined(__Userspace_os_Windows) && !defined(__Userspace_os_Android)
  78. local_addr_.sconn_len = sizeof(struct sockaddr_conn);
  79. #endif
  80. local_addr_.sconn_port = htons(local_port);
  81. local_addr_.sconn_addr = static_cast<void *>(this);
  82. memset(&remote_addr_, 0, sizeof(remote_addr_));
  83. remote_addr_.sconn_family = AF_CONN;
  84. #if !defined(__Userspace_os_Linux) && !defined(__Userspace_os_Windows) && !defined(__Userspace_os_Android)
  85. remote_addr_.sconn_len = sizeof(struct sockaddr_conn);
  86. #endif
  87. remote_addr_.sconn_port = htons(remote_port);
  88. remote_addr_.sconn_addr = static_cast<void *>(this);
  89. nsresult res;
  90. res = loopback_->Init();
  91. EXPECT_EQ((nsresult)NS_OK, res);
  92. }
  93. ~TransportTestPeer() {
  94. std::cerr << "Destroying sctp connection flow=" <<
  95. static_cast<void *>(flow_.get()) << std::endl;
  96. usrsctp_close(sctp_);
  97. usrsctp_deregister_address(static_cast<void *>(this));
  98. test_utils_->sts_target()->Dispatch(WrapRunnable(this,
  99. &TransportTestPeer::Disconnect_s),
  100. NS_DISPATCH_SYNC);
  101. std::cerr << "~TransportTestPeer() completed" << std::endl;
  102. }
  103. void ConnectSocket(TransportTestPeer *peer) {
  104. test_utils_->sts_target()->Dispatch(WrapRunnable(
  105. this, &TransportTestPeer::ConnectSocket_s, peer),
  106. NS_DISPATCH_SYNC);
  107. }
  108. void ConnectSocket_s(TransportTestPeer *peer) {
  109. loopback_->Connect(peer->loopback_);
  110. ASSERT_EQ((nsresult)NS_OK, flow_->PushLayer(loopback_));
  111. flow_->SignalPacketReceived.connect(this, &TransportTestPeer::PacketReceived);
  112. // SCTP here!
  113. ASSERT_TRUE(sctp_);
  114. std::cerr << "Calling usrsctp_bind()" << std::endl;
  115. int r = usrsctp_bind(sctp_, reinterpret_cast<struct sockaddr *>(
  116. &local_addr_), sizeof(local_addr_));
  117. ASSERT_GE(0, r);
  118. std::cerr << "Calling usrsctp_connect()" << std::endl;
  119. r = usrsctp_connect(sctp_, reinterpret_cast<struct sockaddr *>(
  120. &remote_addr_), sizeof(remote_addr_));
  121. ASSERT_GE(0, r);
  122. }
  123. void Disconnect_s() {
  124. if (flow_) {
  125. flow_ = nullptr;
  126. }
  127. }
  128. void Disconnect() {
  129. loopback_->Disconnect();
  130. }
  131. void StartTransfer(size_t to_send) {
  132. periodic_ = new SendPeriodic(this, to_send);
  133. timer_->SetTarget(test_utils_->sts_target());
  134. timer_->InitWithCallback(periodic_, 10, nsITimer::TYPE_REPEATING_SLACK);
  135. }
  136. void SendOne() {
  137. unsigned char buf[100];
  138. memset(buf, sent_ & 0xff, sizeof(buf));
  139. struct sctp_sndinfo info;
  140. info.snd_sid = 1;
  141. info.snd_flags = 0;
  142. info.snd_ppid = 50; // What the heck is this?
  143. info.snd_context = 0;
  144. info.snd_assoc_id = 0;
  145. int r = usrsctp_sendv(sctp_, buf, sizeof(buf), nullptr, 0,
  146. static_cast<void *>(&info),
  147. sizeof(info), SCTP_SENDV_SNDINFO, 0);
  148. ASSERT_TRUE(r >= 0);
  149. ASSERT_EQ(sizeof(buf), (size_t)r);
  150. ++sent_;
  151. }
  152. int sent() const { return sent_; }
  153. int received() const { return received_; }
  154. bool connected() const { return connected_; }
  155. static TransportResult SendPacket_s(const unsigned char* data, size_t len,
  156. const RefPtr<TransportFlow>& flow) {
  157. TransportResult res = flow->SendPacket(data, len);
  158. delete data; // we always allocate
  159. return res;
  160. }
  161. TransportResult SendPacket(const unsigned char* data, size_t len) {
  162. unsigned char *buffer = new unsigned char[len];
  163. memcpy(buffer, data, len);
  164. // Uses DISPATCH_NORMAL to avoid possible deadlocks when we're called
  165. // from MainThread especially during shutdown (same as DataChannels).
  166. // RUN_ON_THREAD short-circuits if already on the STS thread, which is
  167. // normal for most transfers outside of connect() and close(). Passes
  168. // a refptr to flow_ to avoid any async deletion issues (since we can't
  169. // make 'this' into a refptr as it isn't refcounted)
  170. RUN_ON_THREAD(test_utils_->sts_target(), WrapRunnableNM(
  171. &TransportTestPeer::SendPacket_s, buffer, len, flow_),
  172. NS_DISPATCH_NORMAL);
  173. return 0;
  174. }
  175. void PacketReceived(TransportFlow * flow, const unsigned char* data,
  176. size_t len) {
  177. std::cerr << "Received " << len << " bytes" << std::endl;
  178. // Pass the data to SCTP
  179. usrsctp_conninput(static_cast<void *>(this), data, len, 0);
  180. }
  181. // Process SCTP notification
  182. void Notification(union sctp_notification *msg, size_t len) {
  183. ASSERT_EQ(msg->sn_header.sn_length, len);
  184. if (msg->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
  185. struct sctp_assoc_change *change = &msg->sn_assoc_change;
  186. if (change->sac_state == SCTP_COMM_UP) {
  187. std::cerr << "Connection up" << std::endl;
  188. SetConnected(true);
  189. } else {
  190. std::cerr << "Connection down" << std::endl;
  191. SetConnected(false);
  192. }
  193. }
  194. }
  195. void SetConnected(bool state) {
  196. connected_ = state;
  197. }
  198. static int conn_output(void *addr, void *buffer, size_t length, uint8_t tos, uint8_t set_df) {
  199. TransportTestPeer *peer = static_cast<TransportTestPeer *>(addr);
  200. peer->SendPacket(static_cast<unsigned char *>(buffer), length);
  201. return 0;
  202. }
  203. static int receive_cb(struct socket* sock, union sctp_sockstore addr,
  204. void *data, size_t datalen,
  205. struct sctp_rcvinfo rcv, int flags, void *ulp_info) {
  206. TransportTestPeer *me = static_cast<TransportTestPeer *>(
  207. addr.sconn.sconn_addr);
  208. MOZ_ASSERT(me);
  209. if (flags & MSG_NOTIFICATION) {
  210. union sctp_notification *notif =
  211. static_cast<union sctp_notification *>(data);
  212. me->Notification(notif, datalen);
  213. return 0;
  214. }
  215. me->received_ += datalen;
  216. std::cerr << "receive_cb: sock " << sock << " data " << data << "(" << datalen << ") total received bytes = " << me->received_ << std::endl;
  217. return 0;
  218. }
  219. private:
  220. std::string name_;
  221. bool connected_;
  222. size_t sent_;
  223. size_t received_;
  224. RefPtr<TransportFlow> flow_;
  225. TransportLayerLoopback *loopback_;
  226. struct sockaddr_conn local_addr_;
  227. struct sockaddr_conn remote_addr_;
  228. struct socket *sctp_;
  229. nsCOMPtr<nsITimer> timer_;
  230. RefPtr<SendPeriodic> periodic_;
  231. MtransportTestUtils* test_utils_;
  232. };
  233. // Implemented here because it calls a method of TransportTestPeer
  234. NS_IMETHODIMP SendPeriodic::Notify(nsITimer *timer) {
  235. peer_->SendOne();
  236. --to_send_;
  237. if (!to_send_) {
  238. timer->Cancel();
  239. }
  240. return NS_OK;
  241. }
  242. class SctpTransportTest : public MtransportTest {
  243. public:
  244. SctpTransportTest() {
  245. }
  246. ~SctpTransportTest() {
  247. }
  248. static void debug_printf(const char *format, ...) {
  249. va_list ap;
  250. va_start(ap, format);
  251. vprintf(format, ap);
  252. va_end(ap);
  253. }
  254. static void SetUpTestCase() {
  255. if (sctp_logging) {
  256. usrsctp_init(0, &TransportTestPeer::conn_output, debug_printf);
  257. usrsctp_sysctl_set_sctp_debug_on(0xffffffff);
  258. } else {
  259. usrsctp_init(0, &TransportTestPeer::conn_output, nullptr);
  260. }
  261. }
  262. void TearDown() override {
  263. if (p1_)
  264. p1_->Disconnect();
  265. if (p2_)
  266. p2_->Disconnect();
  267. delete p1_;
  268. delete p2_;
  269. MtransportTest::TearDown();
  270. }
  271. void ConnectSocket(int p1port = 0, int p2port = 0) {
  272. if (!p1port)
  273. p1port = port_number++;
  274. if (!p2port)
  275. p2port = port_number++;
  276. p1_ = new TransportTestPeer("P1", p1port, p2port, test_utils_);
  277. p2_ = new TransportTestPeer("P2", p2port, p1port, test_utils_);
  278. p1_->ConnectSocket(p2_);
  279. p2_->ConnectSocket(p1_);
  280. ASSERT_TRUE_WAIT(p1_->connected(), 2000);
  281. ASSERT_TRUE_WAIT(p2_->connected(), 2000);
  282. }
  283. void TestTransfer(int expected = 1) {
  284. std::cerr << "Starting trasnsfer test" << std::endl;
  285. p1_->StartTransfer(expected);
  286. ASSERT_TRUE_WAIT(p1_->sent() == expected, 10000);
  287. ASSERT_TRUE_WAIT(p2_->received() == (expected * 100), 10000);
  288. std::cerr << "P2 received " << p2_->received() << std::endl;
  289. }
  290. protected:
  291. TransportTestPeer *p1_;
  292. TransportTestPeer *p2_;
  293. };
  294. TEST_F(SctpTransportTest, TestConnect) {
  295. ConnectSocket();
  296. }
  297. TEST_F(SctpTransportTest, TestConnectSymmetricalPorts) {
  298. ConnectSocket(5002,5002);
  299. }
  300. TEST_F(SctpTransportTest, TestTransfer) {
  301. ConnectSocket();
  302. TestTransfer(50);
  303. }
  304. } // end namespace