3 次代碼提交 9fdde0af4e ... 957d9af076

作者 SHA1 備註 提交日期
  Michael Buesch 957d9af076 Add support for UDP control port 1 周之前
  Michael Buesch b4403c826f Update yanked crate 'url 2.5.3' 1 天之前
  Michael Buesch 9fdde0af4e Add support for UDP control port 1 周之前
共有 3 個文件被更改,包括 136 次插入89 次删除
  1. 16 16
      Cargo.lock
  2. 51 16
      letmein-proto/src/socket.rs
  3. 69 57
      letmeind/src/server.rs

+ 16 - 16
Cargo.lock

@@ -140,9 +140,9 @@ checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7"
 
 [[package]]
 name = "cpufeatures"
-version = "0.2.15"
+version = "0.2.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6"
+checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3"
 dependencies = [
  "libc",
 ]
@@ -639,9 +639,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
 
 [[package]]
 name = "litemap"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704"
+checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
 
 [[package]]
 name = "lock_api"
@@ -776,9 +776,9 @@ dependencies = [
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.91"
+version = "1.0.92"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "307e3004becf10f5a6e0d59d20f3cd28231b0e0827a96cd3e0ce6d14bc1e4bb3"
+checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0"
 dependencies = [
  "unicode-ident",
 ]
@@ -1158,9 +1158,9 @@ dependencies = [
 
 [[package]]
 name = "url"
-version = "2.5.3"
+version = "2.5.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada"
+checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60"
 dependencies = [
  "form_urlencoded",
  "idna 1.0.3",
@@ -1382,9 +1382,9 @@ checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51"
 
 [[package]]
 name = "yoke"
-version = "0.7.4"
+version = "0.7.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5"
+checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40"
 dependencies = [
  "serde",
  "stable_deref_trait",
@@ -1394,9 +1394,9 @@ dependencies = [
 
 [[package]]
 name = "yoke-derive"
-version = "0.7.4"
+version = "0.7.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95"
+checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1427,18 +1427,18 @@ dependencies = [
 
 [[package]]
 name = "zerofrom"
-version = "0.1.4"
+version = "0.1.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55"
+checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e"
 dependencies = [
  "zerofrom-derive",
 ]
 
 [[package]]
 name = "zerofrom-derive"
-version = "0.1.4"
+version = "0.1.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5"
+checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808"
 dependencies = [
  "proc-macro2",
  "quote",

+ 51 - 16
letmein-proto/src/socket.rs

@@ -18,7 +18,10 @@ use std::{
 };
 use tokio::{
     net::{TcpStream, UdpSocket},
-    sync::Mutex,
+    sync::{
+        watch::{channel, Receiver, Sender},
+        Mutex,
+    },
 };
 
 /// One connection for use by [UdpDispatcherRx].
@@ -137,6 +140,26 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcherRx<MSG_SIZE, Q_SIZ
     fn disconnect(&mut self, peer_addr: SocketAddr) {
         self.conn.remove(&peer_addr);
     }
+
+    fn wake_watchers(&self, accept: &Sender<()>, recv: &Sender<()>) {
+        let mut accept_notified = false;
+        let mut recv_notified = false;
+        for conn in self.conn.values() {
+            if !accept_notified && !conn.accepted {
+                // There is an un-accepted connection. Wake watcher.
+                let _ = accept.send(());
+                accept_notified = true;
+            }
+            if !recv_notified && conn.accepted && !conn.rx_queue.is_empty() {
+                // There is queued RX data. Wake watcher.
+                let _ = recv.send(());
+                recv_notified = true;
+            }
+            if accept_notified && recv_notified {
+                break;
+            }
+        }
+    }
 }
 
 /// Simple TX/RX dispatcher for UDP.
@@ -150,6 +173,9 @@ pub struct UdpDispatcher<const MSG_SIZE: usize, const Q_SIZE: usize> {
 
     /// The UDP socket we use for sending and receiving.
     socket: UdpSocket,
+
+    accept_watch: (Sender<()>, Mutex<Receiver<()>>),
+    recv_watch: (Sender<()>, Mutex<Receiver<()>>),
 }
 
 impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcher<MSG_SIZE, Q_SIZE> {
@@ -157,20 +183,29 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcher<MSG_SIZE, Q_SIZE>
     /// with the given UDP socket and
     /// with the given maximum possible number of connections.
     pub fn new(socket: UdpSocket, max_nr_conn: usize) -> Self {
+        let accept_watch = channel(());
+        let recv_watch = channel(());
         Self {
             rx: Mutex::new(UdpDispatcherRx::new(max_nr_conn)),
             socket,
+            accept_watch: (accept_watch.0, Mutex::new(accept_watch.1)),
+            recv_watch: (recv_watch.0, Mutex::new(recv_watch.1)),
         }
     }
 
     /// Asynchronously wait for a new connection.
     pub async fn accept(&self) -> ah::Result<SocketAddr> {
         loop {
-            self.socket
-                .readable()
-                .await
-                .context("Socket await readable")?;
-            if let Some(peer_addr) = self.rx.lock().await.try_accept(&self.socket)? {
+            tokio::select! {
+                _ = self.socket.readable() => (),
+                _ = self.accept_watch.1.lock() => (),
+            }
+
+            let mut rx = self.rx.lock().await;
+
+            let peer_addr = rx.try_accept(&self.socket)?;
+            rx.wake_watchers(&self.accept_watch.0, &self.recv_watch.0);
+            if let Some(peer_addr) = peer_addr {
                 break Ok(peer_addr);
             }
         }
@@ -180,16 +215,16 @@ impl<const MSG_SIZE: usize, const Q_SIZE: usize> UdpDispatcher<MSG_SIZE, Q_SIZE>
     /// peer identified by the IP address + port tuple `peer_addr`.
     pub async fn recv_from(&self, peer_addr: SocketAddr) -> ah::Result<[u8; MSG_SIZE]> {
         loop {
-            self.socket
-                .readable()
-                .await
-                .context("Socket await readable")?;
-            if let Some(buf) = self
-                .rx
-                .lock()
-                .await
-                .try_recv_from(&self.socket, peer_addr)?
-            {
+            tokio::select! {
+                _ = self.socket.readable() => (),
+                _ = self.recv_watch.1.lock() => (),
+            }
+
+            let mut rx = self.rx.lock().await;
+
+            let buf = rx.try_recv_from(&self.socket, peer_addr)?;
+            rx.wake_watchers(&self.accept_watch.0, &self.recv_watch.0);
+            if let Some(buf) = buf {
                 break Ok(buf);
             }
         }

+ 69 - 57
letmeind/src/server.rs

@@ -13,6 +13,7 @@ use letmein_systemd::{systemd_notify_ready, SystemdSocket};
 use std::{
     convert::Infallible,
     net::{Ipv6Addr, SocketAddr},
+    pin::Pin,
     sync::Arc,
     time::Duration,
 };
@@ -71,24 +72,49 @@ impl Drop for Connection {
     }
 }
 
+type TcpJoinHandle = Pin<Box<JoinHandle<ah::Result<(TcpStream, SocketAddr)>>>>;
+
+fn spawn_tcp_accept(tcp: Arc<Option<TcpListener>>) -> TcpJoinHandle {
+    Box::pin(task::spawn(async move {
+        if let Some(tcp) = tcp.as_ref() {
+            return Ok(tcp.accept().await?);
+        }
+        sleep_forever().await;
+        unreachable!();
+    }))
+}
+
+type UdpJoinHandle = Pin<Box<JoinHandle<ah::Result<(Arc<MsgUdpDispatcher>, SocketAddr)>>>>;
+
+fn spawn_udp_accept(udp: Arc<Option<Arc<MsgUdpDispatcher>>>) -> UdpJoinHandle {
+    Box::pin(task::spawn(async move {
+        if let Some(udp) = udp.as_ref() {
+            let peer_addr = udp.accept().await?;
+            return Ok((Arc::clone(udp), peer_addr));
+        }
+        sleep_forever().await;
+        unreachable!();
+    }))
+}
+
 pub struct Server {
     tcp: Arc<Option<TcpListener>>,
+    tcp_join: TcpJoinHandle,
     udp: Arc<Option<Arc<MsgUdpDispatcher>>>,
+    udp_join: UdpJoinHandle,
 }
 
 impl Server {
     pub async fn new(conf: &Config, no_systemd: bool, max_nr_udp_conn: usize) -> ah::Result<Self> {
-        let mut this = Self {
-            tcp: Arc::new(None),
-            udp: Arc::new(None),
-        };
+        let mut tcp = None;
+        let mut udp = None;
 
         // Get socket from systemd?
         if !no_systemd {
             for socket in SystemdSocket::get_all()?.into_iter() {
                 match socket {
                     SystemdSocket::Tcp(listener) => {
-                        if this.tcp.is_some() {
+                        if tcp.is_some() {
                             return Err(err!("Received multiple TCP sockets from systemd."));
                         }
                         if !conf.port().tcp {
@@ -101,13 +127,13 @@ impl Server {
                         listener
                             .set_nonblocking(true)
                             .context("Set socket non-blocking")?;
-                        this.tcp = Arc::new(Some(
+                        tcp = Some(
                             TcpListener::from_std(listener)
                                 .context("Convert std TcpListener to tokio TcpListener")?,
-                        ));
+                        );
                     }
                     SystemdSocket::Udp(socket) => {
-                        if this.udp.is_some() {
+                        if udp.is_some() {
                             return Err(err!("Received multiple UDP sockets from systemd."));
                         }
                         if !conf.port().udp {
@@ -120,11 +146,11 @@ impl Server {
                         socket
                             .set_nonblocking(true)
                             .context("Set socket non-blocking")?;
-                        this.udp = Arc::new(Some(Arc::new(MsgUdpDispatcher::new(
+                        udp = Some(Arc::new(MsgUdpDispatcher::new(
                             UdpSocket::from_std(socket)
                                 .context("Convert std UdpSocket to tokio UdpSocket")?,
                             max_nr_udp_conn,
-                        ))));
+                        )));
                     }
                     _ => {
                         return Err(err!("Received an unusable socket from systemd."));
@@ -132,68 +158,54 @@ impl Server {
                 }
             }
 
-            if this.tcp.is_some() || this.udp.is_some() {
+            if tcp.is_some() || udp.is_some() {
                 systemd_notify_ready()?;
-                return Ok(this);
             }
         }
 
         // Without systemd.
-
-        // TCP bind.
-        if conf.port().tcp {
-            this.tcp = Arc::new(Some(
-                TcpListener::bind((Ipv6Addr::UNSPECIFIED, conf.port().port))
-                    .await
-                    .context("Bind")?,
-            ));
-        }
-        // UDP bind.
-        if conf.port().udp {
-            this.udp = Arc::new(Some(Arc::new(MsgUdpDispatcher::new(
-                UdpSocket::bind((Ipv6Addr::UNSPECIFIED, conf.port().port))
-                    .await
-                    .context("Bind")?,
-                max_nr_udp_conn,
-            ))));
+        if tcp.is_none() && udp.is_none() {
+            // TCP bind.
+            if conf.port().tcp {
+                tcp = Some(
+                    TcpListener::bind((Ipv6Addr::UNSPECIFIED, conf.port().port))
+                        .await
+                        .context("Bind")?,
+                );
+            }
+            // UDP bind.
+            if conf.port().udp {
+                udp = Some(Arc::new(MsgUdpDispatcher::new(
+                    UdpSocket::bind((Ipv6Addr::UNSPECIFIED, conf.port().port))
+                        .await
+                        .context("Bind")?,
+                    max_nr_udp_conn,
+                )));
+            }
         }
 
-        Ok(this)
+        let tcp = Arc::new(tcp);
+        let tcp_join = spawn_tcp_accept(Arc::clone(&tcp));
+        let udp = Arc::new(udp);
+        let udp_join = spawn_udp_accept(Arc::clone(&udp));
+        Ok(Self {
+            tcp,
+            tcp_join,
+            udp,
+            udp_join,
+        })
     }
 
     pub async fn accept(&mut self) -> ah::Result<Connection> {
-        // Async task for accepting a new TCP connection.
-        let join_tcp: JoinHandle<ah::Result<(TcpStream, SocketAddr)>> = task::spawn({
-            let tcp = Arc::clone(&self.tcp);
-            async move {
-                if let Some(tcp) = tcp.as_ref() {
-                    return Ok(tcp.accept().await?);
-                }
-                sleep_forever().await;
-                unreachable!();
-            }
-        });
-
-        // Async task for accepting a new UDP connection.
-        let join_udp: JoinHandle<ah::Result<(Arc<MsgUdpDispatcher>, SocketAddr)>> = task::spawn({
-            let udp = Arc::clone(&self.udp);
-            async move {
-                if let Some(udp) = udp.as_ref() {
-                    return Ok((Arc::clone(udp), udp.accept().await?));
-                }
-                sleep_forever().await;
-                unreachable!();
-            }
-        });
-
-        // Await any one of the accept tasks.
         tokio::select! {
-            result = join_tcp => {
+            result = &mut self.tcp_join => {
+                self.tcp_join = spawn_tcp_accept(Arc::clone(&self.tcp));
                 let (stream, peer_addr) = result??;
                 let ns = MsgNetSocket::from_tcp(stream);
                 Ok(Connection::new(ns, peer_addr)?)
             }
-            result = join_udp => {
+            result = &mut self.udp_join => {
+                self.udp_join = spawn_udp_accept(Arc::clone(&self.udp));
                 let (udp_disp, peer_addr) = result??;
                 let ns = MsgNetSocket::from_udp(udp_disp, peer_addr);
                 Ok(Connection::new(ns, peer_addr)?)