// TraversalClient.cpp
  1. // SPDX-License-Identifier: CC0-1.0
  2. #include "Common/TraversalClient.h"
  3. #include <cstddef>
  4. #include <cstring>
  5. #include <string>
  6. #include "Common/CommonTypes.h"
  7. #include "Common/Logging/Log.h"
  8. #include "Common/MsgHandler.h"
  9. #include "Common/Random.h"
  10. #include "Core/NetPlayProto.h"
  11. namespace Common
  12. {
  13. TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port,
  14. const u16 port_alt)
  15. : m_NetHost(netHost), m_Server(server), m_port(port), m_portAlt(port_alt)
  16. {
  17. netHost->intercept = TraversalClient::InterceptCallback;
  18. Reset();
  19. ReconnectToServer();
  20. }
  21. TraversalClient::~TraversalClient() = default;
  22. TraversalHostId TraversalClient::GetHostID() const
  23. {
  24. return m_HostId;
  25. }
  26. TraversalInetAddress TraversalClient::GetExternalAddress() const
  27. {
  28. return m_external_address;
  29. }
  30. TraversalClient::State TraversalClient::GetState() const
  31. {
  32. return m_State;
  33. }
  34. TraversalClient::FailureReason TraversalClient::GetFailureReason() const
  35. {
  36. return m_FailureReason;
  37. }
  38. void TraversalClient::ReconnectToServer()
  39. {
  40. if (enet_address_set_host(&m_ServerAddress, m_Server.c_str()))
  41. {
  42. OnFailure(FailureReason::BadHost);
  43. return;
  44. }
  45. m_ServerAddress.port = m_port;
  46. m_State = State::Connecting;
  47. TraversalPacket hello = {};
  48. hello.type = TraversalPacketType::HelloFromClient;
  49. hello.helloFromClient.protoVersion = TraversalProtoVersion;
  50. SendTraversalPacket(hello);
  51. if (m_Client)
  52. m_Client->OnTraversalStateChanged();
  53. }
  54. static ENetAddress MakeENetAddress(const TraversalInetAddress& address)
  55. {
  56. ENetAddress eaddr{};
  57. if (address.isIPV6)
  58. {
  59. eaddr.port = 0; // no support yet :(
  60. }
  61. else
  62. {
  63. eaddr.host = address.address[0];
  64. eaddr.port = ntohs(address.port);
  65. }
  66. return eaddr;
  67. }
  68. void TraversalClient::ConnectToClient(std::string_view host)
  69. {
  70. if (host.size() > sizeof(TraversalHostId))
  71. {
  72. PanicAlertFmt("Host too long");
  73. return;
  74. }
  75. TraversalPacket packet = {};
  76. packet.type = TraversalPacketType::ConnectPlease;
  77. memcpy(packet.connectPlease.hostId.data(), host.data(), host.size());
  78. m_ConnectRequestId = SendTraversalPacket(packet);
  79. m_PendingConnect = true;
  80. }
  81. bool TraversalClient::TestPacket(u8* data, size_t size, ENetAddress* from)
  82. {
  83. if (from->host == m_ServerAddress.host && from->port == m_ServerAddress.port)
  84. {
  85. if (size < sizeof(TraversalPacket))
  86. {
  87. ERROR_LOG_FMT(NETPLAY, "Received too-short traversal packet.");
  88. }
  89. else
  90. {
  91. HandleServerPacket((TraversalPacket*)data);
  92. return true;
  93. }
  94. }
  95. return false;
  96. }
  97. //--Temporary until more of the old netplay branch is moved over
  98. void TraversalClient::Update()
  99. {
  100. ENetEvent netEvent;
  101. if (enet_host_service(m_NetHost, &netEvent, 4) > 0)
  102. {
  103. switch (netEvent.type)
  104. {
  105. case ENET_EVENT_TYPE_RECEIVE:
  106. TestPacket(netEvent.packet->data, netEvent.packet->dataLength, &netEvent.peer->address);
  107. enet_packet_destroy(netEvent.packet);
  108. break;
  109. default:
  110. break;
  111. }
  112. }
  113. HandleResends();
  114. }
  115. void TraversalClient::HandleServerPacket(TraversalPacket* packet)
  116. {
  117. u8 ok = 1;
  118. switch (packet->type)
  119. {
  120. case TraversalPacketType::Ack:
  121. if (!packet->ack.ok)
  122. {
  123. OnFailure(FailureReason::ServerForgotAboutUs);
  124. break;
  125. }
  126. for (auto it = m_OutgoingTraversalPackets.begin(); it != m_OutgoingTraversalPackets.end(); ++it)
  127. {
  128. if (it->packet.requestId == packet->requestId)
  129. {
  130. if (packet->requestId == m_TestRequestId)
  131. HandleTraversalTest();
  132. m_OutgoingTraversalPackets.erase(it);
  133. break;
  134. }
  135. }
  136. break;
  137. case TraversalPacketType::HelloFromServer:
  138. if (!IsConnecting())
  139. break;
  140. if (!packet->helloFromServer.ok)
  141. {
  142. OnFailure(FailureReason::VersionTooOld);
  143. break;
  144. }
  145. m_HostId = packet->helloFromServer.yourHostId;
  146. m_external_address = packet->helloFromServer.yourAddress;
  147. NewTraversalTest();
  148. m_State = State::Connected;
  149. if (m_Client)
  150. m_Client->OnTraversalStateChanged();
  151. break;
  152. case TraversalPacketType::PleaseSendPacket:
  153. {
  154. // security is overrated.
  155. ENetAddress addr = MakeENetAddress(packet->pleaseSendPacket.address);
  156. if (addr.port != 0)
  157. {
  158. char message[] = "Hello from Dolphin Netplay...";
  159. ENetBuffer buf;
  160. buf.data = message;
  161. buf.dataLength = sizeof(message) - 1;
  162. if (m_ttlReady)
  163. {
  164. int oldttl;
  165. enet_socket_get_option(m_NetHost->socket, ENET_SOCKOPT_TTL, &oldttl);
  166. enet_socket_set_option(m_NetHost->socket, ENET_SOCKOPT_TTL, m_ttl);
  167. enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
  168. enet_socket_set_option(m_NetHost->socket, ENET_SOCKOPT_TTL, oldttl);
  169. }
  170. else
  171. {
  172. enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
  173. }
  174. }
  175. else
  176. {
  177. // invalid IPV6
  178. ok = 0;
  179. }
  180. break;
  181. }
  182. case TraversalPacketType::ConnectReady:
  183. case TraversalPacketType::ConnectFailed:
  184. {
  185. if (!m_PendingConnect || packet->connectReady.requestId != m_ConnectRequestId)
  186. break;
  187. m_PendingConnect = false;
  188. if (!m_Client)
  189. break;
  190. if (packet->type == TraversalPacketType::ConnectReady)
  191. m_Client->OnConnectReady(MakeENetAddress(packet->connectReady.address));
  192. else
  193. m_Client->OnConnectFailed(packet->connectFailed.reason);
  194. break;
  195. }
  196. default:
  197. WARN_LOG_FMT(NETPLAY, "Received unknown packet with type {}", static_cast<int>(packet->type));
  198. break;
  199. }
  200. if (packet->type != TraversalPacketType::Ack)
  201. {
  202. TraversalPacket ack = {};
  203. ack.type = TraversalPacketType::Ack;
  204. ack.requestId = packet->requestId;
  205. ack.ack.ok = ok;
  206. ENetBuffer buf;
  207. buf.data = &ack;
  208. buf.dataLength = sizeof(ack);
  209. if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
  210. OnFailure(FailureReason::SocketSendError);
  211. }
  212. }
  213. void TraversalClient::OnFailure(FailureReason reason)
  214. {
  215. m_State = State::Failure;
  216. m_FailureReason = reason;
  217. if (m_Client)
  218. m_Client->OnTraversalStateChanged();
  219. }
  220. void TraversalClient::ResendPacket(OutgoingTraversalPacketInfo* info)
  221. {
  222. bool testPacket =
  223. m_TestSocket != ENET_SOCKET_NULL && info->packet.type == TraversalPacketType::TestPlease;
  224. info->sendTime = enet_time_get();
  225. info->tries++;
  226. ENetBuffer buf;
  227. buf.data = &info->packet;
  228. buf.dataLength = sizeof(info->packet);
  229. if (enet_socket_send(testPacket ? m_TestSocket : m_NetHost->socket, &m_ServerAddress, &buf, 1) ==
  230. -1)
  231. OnFailure(FailureReason::SocketSendError);
  232. }
  233. void TraversalClient::HandleResends()
  234. {
  235. const u32 now = enet_time_get();
  236. for (auto& tpi : m_OutgoingTraversalPackets)
  237. {
  238. if (now - tpi.sendTime >= (u32)(300 * tpi.tries))
  239. {
  240. if (tpi.tries >= 5)
  241. {
  242. OnFailure(FailureReason::ResendTimeout);
  243. m_OutgoingTraversalPackets.clear();
  244. break;
  245. }
  246. else
  247. {
  248. ResendPacket(&tpi);
  249. }
  250. }
  251. }
  252. HandlePing();
  253. }
  254. void TraversalClient::HandlePing()
  255. {
  256. const u32 now = enet_time_get();
  257. if (IsConnected() && now - m_PingTime >= 500)
  258. {
  259. TraversalPacket ping = {};
  260. ping.type = TraversalPacketType::Ping;
  261. ping.ping.hostId = m_HostId;
  262. SendTraversalPacket(ping);
  263. m_PingTime = now;
  264. }
  265. }
  266. void TraversalClient::NewTraversalTest()
  267. {
  268. // create test socket
  269. if (m_TestSocket != ENET_SOCKET_NULL)
  270. enet_socket_destroy(m_TestSocket);
  271. m_TestSocket = enet_socket_create(ENET_SOCKET_TYPE_DATAGRAM);
  272. ENetAddress addr = {ENET_HOST_ANY, 0};
  273. if (m_TestSocket == ENET_SOCKET_NULL || enet_socket_bind(m_TestSocket, &addr) < 0)
  274. {
  275. // error, abort
  276. if (m_TestSocket != ENET_SOCKET_NULL)
  277. {
  278. enet_socket_destroy(m_TestSocket);
  279. m_TestSocket = ENET_SOCKET_NULL;
  280. }
  281. return;
  282. }
  283. enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_NONBLOCK, 1);
  284. // create holepunch packet
  285. TraversalPacket packet = {};
  286. packet.type = TraversalPacketType::Ping;
  287. packet.ping.hostId = m_HostId;
  288. packet.requestId = Common::Random::GenerateValue<TraversalRequestId>();
  289. // create buffer
  290. ENetBuffer buf;
  291. buf.data = &packet;
  292. buf.dataLength = sizeof(packet);
  293. // send to alt port
  294. ENetAddress altAddress = m_ServerAddress;
  295. altAddress.port = m_portAlt;
  296. // set up ttl and send
  297. int oldttl;
  298. enet_socket_get_option(m_TestSocket, ENET_SOCKOPT_TTL, &oldttl);
  299. enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_TTL, m_ttl);
  300. if (enet_socket_send(m_TestSocket, &altAddress, &buf, 1) == -1)
  301. {
  302. // error, abort
  303. enet_socket_destroy(m_TestSocket);
  304. m_TestSocket = ENET_SOCKET_NULL;
  305. return;
  306. }
  307. enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_TTL, oldttl);
  308. // send the test request
  309. packet.type = TraversalPacketType::TestPlease;
  310. m_TestRequestId = SendTraversalPacket(packet);
  311. }
  312. void TraversalClient::HandleTraversalTest()
  313. {
  314. if (m_TestSocket != ENET_SOCKET_NULL)
  315. {
  316. // check for packet on test socket (with timeout)
  317. u32 deadline = enet_time_get() + 50;
  318. u32 waitCondition;
  319. do
  320. {
  321. waitCondition = ENET_SOCKET_WAIT_RECEIVE | ENET_SOCKET_WAIT_INTERRUPT;
  322. u32 currentTime = enet_time_get();
  323. if (currentTime > deadline ||
  324. enet_socket_wait(m_TestSocket, &waitCondition, deadline - currentTime) != 0)
  325. {
  326. // error or timeout, exit the loop and assume test failure
  327. waitCondition = 0;
  328. break;
  329. }
  330. else if (waitCondition & ENET_SOCKET_WAIT_RECEIVE)
  331. {
  332. // try reading the packet and see if it's relevant
  333. ENetAddress raddr;
  334. TraversalPacket packet;
  335. ENetBuffer buf;
  336. buf.data = &packet;
  337. buf.dataLength = sizeof(packet);
  338. int rv = enet_socket_receive(m_TestSocket, &raddr, &buf, 1);
  339. if (rv < 0)
  340. {
  341. // error, exit the loop and assume test failure
  342. waitCondition = 0;
  343. break;
  344. }
  345. else if (rv < int(sizeof(packet)) || raddr.host != m_ServerAddress.host ||
  346. raddr.host != m_portAlt || packet.requestId != m_TestRequestId)
  347. {
  348. // irrelevant packet, ignore
  349. continue;
  350. }
  351. }
  352. } while (waitCondition & ENET_SOCKET_WAIT_INTERRUPT);
  353. // regardless of what happens next, we can throw out the socket
  354. enet_socket_destroy(m_TestSocket);
  355. m_TestSocket = ENET_SOCKET_NULL;
  356. if (waitCondition & ENET_SOCKET_WAIT_RECEIVE)
  357. {
  358. // success, we can stop now
  359. m_ttlReady = true;
  360. m_Client->OnTtlDetermined(m_ttl);
  361. }
  362. else
  363. {
  364. // fail, increment and retry
  365. if (++m_ttl < 32)
  366. NewTraversalTest();
  367. }
  368. }
  369. }
  370. TraversalRequestId TraversalClient::SendTraversalPacket(const TraversalPacket& packet)
  371. {
  372. OutgoingTraversalPacketInfo info;
  373. info.packet = packet;
  374. info.packet.requestId = Common::Random::GenerateValue<TraversalRequestId>();
  375. info.tries = 0;
  376. m_OutgoingTraversalPackets.push_back(info);
  377. ResendPacket(&m_OutgoingTraversalPackets.back());
  378. return info.packet.requestId;
  379. }
  380. void TraversalClient::Reset()
  381. {
  382. m_PendingConnect = false;
  383. m_Client = nullptr;
  384. }
  385. int ENET_CALLBACK TraversalClient::InterceptCallback(ENetHost* host, ENetEvent* event)
  386. {
  387. auto traversalClient = g_TraversalClient.get();
  388. if (traversalClient->TestPacket(host->receivedData, host->receivedDataLength,
  389. &host->receivedAddress) ||
  390. (host->receivedDataLength == 1 && host->receivedData[0] == 0))
  391. {
  392. event->type = static_cast<ENetEventType>(Common::ENet::SKIPPABLE_EVENT);
  393. return 1;
  394. }
  395. return 0;
  396. }
  397. std::unique_ptr<TraversalClient> g_TraversalClient;
  398. ENet::ENetHostPtr g_MainNetHost;
  399. // The settings at the previous TraversalClient reset - notably, we
  400. // need to know not just what port it's on, but whether it was
  401. // explicitly requested.
  402. static std::string g_OldServer;
  403. static u16 g_OldServerPort;
  404. static u16 g_OldServerPortAlt;
  405. static u16 g_OldListenPort;
  406. bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 server_port_alt,
  407. u16 listen_port)
  408. {
  409. if (!g_MainNetHost || !g_TraversalClient || server != g_OldServer ||
  410. server_port != g_OldServerPort || server_port_alt != g_OldServerPortAlt ||
  411. listen_port != g_OldListenPort)
  412. {
  413. g_OldServer = server;
  414. g_OldServerPort = server_port;
  415. g_OldServerPortAlt = server_port_alt;
  416. g_OldListenPort = listen_port;
  417. ENetAddress addr = {ENET_HOST_ANY, listen_port};
  418. auto host = Common::ENet::ENetHostPtr{enet_host_create(&addr, // address
  419. 50, // peerCount
  420. NetPlay::CHANNEL_COUNT, // channelLimit
  421. 0, // incomingBandwidth
  422. 0)}; // outgoingBandwidth
  423. if (!host)
  424. {
  425. g_MainNetHost.reset();
  426. return false;
  427. }
  428. host->mtu = std::min(host->mtu, NetPlay::MAX_ENET_MTU);
  429. g_MainNetHost = std::move(host);
  430. g_TraversalClient.reset(
  431. new TraversalClient(g_MainNetHost.get(), server, server_port, server_port_alt));
  432. }
  433. return true;
  434. }
  435. void ReleaseTraversalClient()
  436. {
  437. if (!g_TraversalClient)
  438. return;
  439. g_TraversalClient.reset();
  440. g_MainNetHost.reset();
  441. }
  442. } // namespace Common