|
@@ -21,20 +21,40 @@ use tokio::{
|
|
|
sync::Mutex,
|
|
|
};
|
|
|
|
|
|
+/// One connection for use by [UdpDispatcherRx].
|
|
|
#[derive(Debug)]
|
|
|
struct UdpConn<const MSG_SIZE: usize, const Q_SIZE: usize> {
|
|
|
+ /// The receive-queue for this connection.
|
|
|
rx_queue: VecDeque<[u8; MSG_SIZE]>,
|
|
|
+
|
|
|
+ /// The peer IP address + source port tuple for this connection.
|
|
|
peer_addr: SocketAddr,
|
|
|
+
|
|
|
+ /// Is this a new connection that has not been accepted, yet?
|
|
|
accepted: bool,
|
|
|
}
|
|
|
|
|
|
+/// Very simple "connection" tracking for UDP.
|
|
|
+///
|
|
|
+/// Tracking is purely based on the peer's IP address and source port.
|
|
|
+/// There are no other advanced TCP-like functionalities.
|
|
|
+///
|
|
|
+/// The maximum number of connections and the maximum number of packets
|
|
|
+/// in the RX queue are limited.
|
|
|
+/// However, there is no timeout mechanism for the connection.
|
|
|
+/// The caller has to take care of timeout detection and handling.
|
|
|
#[derive(Debug)]
|
|
|
struct UdpDispatcherRx<const MSG_SIZE: usize, const Q_SIZE: usize> {
|
|
|
+ /// All active connections.
|
|
|
conn: HashMap<SocketAddr, UdpConn<MSG_SIZE, Q_SIZE>>,
|
|
|
+
|
|
|
+ /// The maximum possible number of connections.
|
|
|
max_nr_conn: usize,
|
|
|
}
|
|
|
|
|
|
impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcherRx<MSG_SIZE, Q_SIZE> {
|
|
|
+ /// Create a new [UdpDispatcherRx]
|
|
|
+ /// with the given maximum possible number of connections.
|
|
|
fn new(max_nr_conn: usize) -> Self {
|
|
|
UdpDispatcherRx {
|
|
|
conn: HashMap::new(),
|
|
@@ -42,6 +62,7 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcherRx<MSG_SIZE, Q_SIZ
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// Try to receive a new datagram from the socket.
|
|
|
fn try_recv(&mut self, socket: &UdpSocket) -> ah::Result<()> {
|
|
|
let mut buf = [0_u8; MSG_SIZE];
|
|
|
match socket.try_recv_from(&mut buf) {
|
|
@@ -50,6 +71,8 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcherRx<MSG_SIZE, Q_SIZ
|
|
|
return Err(err!("Socket read: Invalid datagram size: {n}"));
|
|
|
}
|
|
|
|
|
|
+ // Add the received datagram to an existing connection
|
|
|
+ // of create a new connection, if there is none, yet.
|
|
|
assert!(self.conn.len() <= self.max_nr_conn);
|
|
|
let conn = self.conn.entry(peer_addr).or_insert_with(|| UdpConn {
|
|
|
rx_queue: VecDeque::new(),
|
|
@@ -57,15 +80,17 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcherRx<MSG_SIZE, Q_SIZ
|
|
|
accepted: false,
|
|
|
});
|
|
|
|
|
|
+ // Check if the RX queue is full
|
|
|
+ // and if not, then push the received datagram to the queue.
|
|
|
if conn.rx_queue.len() >= Q_SIZE {
|
|
|
- self.conn.remove(&peer_addr);
|
|
|
+ self.conn.remove(&peer_addr); // Close connection.
|
|
|
return Err(err!("UDP socket read: RX queue overflow (max={}).", Q_SIZE));
|
|
|
}
|
|
|
-
|
|
|
conn.rx_queue.push_back(buf);
|
|
|
|
|
|
+ // Check if we exceeded the maximum number of connections.
|
|
|
if self.conn.len() > self.max_nr_conn {
|
|
|
- self.conn.remove(&peer_addr);
|
|
|
+ self.conn.remove(&peer_addr); // Close connection.
|
|
|
return Err(err!(
|
|
|
"UDP socket read: Too many connections (max={}).",
|
|
|
self.max_nr_conn
|
|
@@ -78,6 +103,7 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcherRx<MSG_SIZE, Q_SIZ
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// Get the first not-accepted connection, or None.
|
|
|
fn try_accept(&mut self, socket: &UdpSocket) -> ah::Result<Option<SocketAddr>> {
|
|
|
self.try_recv(socket)?;
|
|
|
for conn in &mut self.conn.values_mut() {
|
|
@@ -89,6 +115,7 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcherRx<MSG_SIZE, Q_SIZ
|
|
|
Ok(None)
|
|
|
}
|
|
|
|
|
|
+ /// Get the oldest element from the RX queue.
|
|
|
fn try_recv_from(
|
|
|
&mut self,
|
|
|
socket: &UdpSocket,
|
|
@@ -102,18 +129,26 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcherRx<MSG_SIZE, Q_SIZ
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// Disconnect the connection identified by the `peer_addr`.
|
|
|
fn disconnect(&mut self, peer_addr: SocketAddr) {
|
|
|
self.conn.remove(&peer_addr);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+/// Simple TX/RX dispatcher for UDP.
|
|
|
#[derive(Debug)]
|
|
|
pub struct UdpDispatcher<const MSG_SIZE: usize, const Q_SIZE: usize> {
|
|
|
+ /// RX connection tracking.
|
|
|
rx: Mutex<UdpDispatcherRx<MSG_SIZE, Q_SIZE>>,
|
|
|
+
|
|
|
+ /// The UDP socket we use for sending and receiving.
|
|
|
socket: UdpSocket,
|
|
|
}
|
|
|
|
|
|
impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcher<MSG_SIZE, Q_SIZE> {
|
|
|
+ /// Create a new [UdpDispatcher]
|
|
|
+ /// with the given UDP socket and
|
|
|
+ /// with the given maximum possible number of connections.
|
|
|
pub fn new(socket: UdpSocket, max_nr_conn: usize) -> Self {
|
|
|
Self {
|
|
|
rx: Mutex::new(UdpDispatcherRx::new(max_nr_conn)),
|
|
@@ -121,6 +156,7 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcher<MSG_SIZE, Q_SIZE>
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// Asynchronously wait for a new connection.
|
|
|
pub async fn accept(&self) -> ah::Result<SocketAddr> {
|
|
|
loop {
|
|
|
self.socket
|
|
@@ -133,6 +169,8 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcher<MSG_SIZE, Q_SIZE>
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// Asynchronously wait for a new datagram from the specified
|
|
|
+ /// peer identified by the IP address + port tuple `peer_addr`.
|
|
|
pub async fn recv_from(&self, peer_addr: SocketAddr) -> ah::Result<[u8; MSG_SIZE]> {
|
|
|
loop {
|
|
|
self.socket
|
|
@@ -150,6 +188,8 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcher<MSG_SIZE, Q_SIZE>
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// Asynchronously send a datagram `data` to the specified
|
|
|
+ /// peer identified by the UP address + port tuple `peer_addr`.
|
|
|
pub async fn send_to(&self, peer_addr: SocketAddr, data: [u8; MSG_SIZE]) -> ah::Result<()> {
|
|
|
self.socket
|
|
|
.writable()
|
|
@@ -162,17 +202,22 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcher<MSG_SIZE, Q_SIZE>
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
+ /// Disconnect the connection identified by the `peer_addr`.
|
|
|
pub async fn disconnect(&self, peer_addr: SocketAddr) {
|
|
|
self.rx.lock().await.disconnect(peer_addr);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+/// Socket abstraction for sending and receiving data
|
|
|
+/// over a TCP connection.
|
|
|
#[derive(Debug)]
|
|
|
pub struct NetSocketTcp {
|
|
|
stream: TcpStream,
|
|
|
closed: AtomicBool,
|
|
|
}
|
|
|
|
|
|
+/// Socket abstraction for sending and receiving data
|
|
|
+/// over a UDP connection.
|
|
|
#[derive(Debug)]
|
|
|
pub struct NetSocketUdp<const MSG_SIZE: usize, const Q_SIZE: usize> {
|
|
|
disp: Arc<UdpDispatcher<MSG_SIZE, Q_SIZE>>,
|
|
@@ -180,13 +225,19 @@ pub struct NetSocketUdp<const MSG_SIZE: usize, const Q_SIZE: usize> {
|
|
|
closed: AtomicBool,
|
|
|
}
|
|
|
|
|
|
+/// Socket abstraction for sending and receiving data
|
|
|
+/// over a TCP or UDP connection.
|
|
|
#[derive(Debug)]
|
|
|
pub enum NetSocket<const MSG_SIZE: usize, const Q_SIZE: usize> {
|
|
|
+ /// TCP variant.
|
|
|
Tcp(NetSocketTcp),
|
|
|
+
|
|
|
+ /// UDP variant.
|
|
|
Udp(NetSocketUdp<MSG_SIZE, Q_SIZE>),
|
|
|
}
|
|
|
|
|
|
impl<const MSG_SIZE: usize, const Q_SIZE: usize> NetSocket<MSG_SIZE, Q_SIZE> {
|
|
|
+ /// Create a new [NetSocket] from a [TcpStream] connection.
|
|
|
pub fn from_tcp(stream: TcpStream) -> Self {
|
|
|
Self::Tcp(NetSocketTcp {
|
|
|
stream,
|
|
@@ -194,6 +245,8 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> NetSocket<MSG_SIZE, Q_SIZE> {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+ /// Create a new [NetSocket] from a [UdpDispatcher]
|
|
|
+ /// and the specified connected `peer_addr`.
|
|
|
pub fn from_udp(disp: Arc<UdpDispatcher<MSG_SIZE, Q_SIZE>>, peer_addr: SocketAddr) -> Self {
|
|
|
Self::Udp(NetSocketUdp {
|
|
|
disp,
|
|
@@ -202,6 +255,7 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> NetSocket<MSG_SIZE, Q_SIZE> {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+ /// Send a message to the connected peer.
|
|
|
pub async fn send(&self, buf: &[u8]) -> ah::Result<()> {
|
|
|
// For good measure, check if we're not closed. But this check is racy.
|
|
|
if self.is_closed() {
|
|
@@ -247,6 +301,7 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> NetSocket<MSG_SIZE, Q_SIZE> {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// Receive a message from the connected peer.
|
|
|
pub async fn recv(&self) -> ah::Result<Option<[u8; MSG_SIZE]>> {
|
|
|
// For good measure, check if we're not closed. But this check is racy.
|
|
|
if self.is_closed() {
|
|
@@ -289,6 +344,12 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> NetSocket<MSG_SIZE, Q_SIZE> {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// Close this connection.
|
|
|
+ ///
|
|
|
+ /// This only has an effect on UDP.
|
|
|
+ /// This does not actually close the TCP connection.
|
|
|
+ /// However, it marks both UDP and TCP as closed and no further
|
|
|
+ /// TX/RX can happen.
|
|
|
pub async fn close(&self) {
|
|
|
match self {
|
|
|
Self::Tcp(inner) => {
|
|
@@ -302,6 +363,7 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> NetSocket<MSG_SIZE, Q_SIZE> {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// Check if this connection is marked as closed.
|
|
|
pub fn is_closed(&self) -> bool {
|
|
|
match self {
|
|
|
Self::Tcp(inner) => inner.closed.load(atomic::Ordering::SeqCst),
|