|
@@ -13,6 +13,7 @@ use letmein_systemd::{systemd_notify_ready, SystemdSocket};
|
|
|
use std::{
|
|
|
convert::Infallible,
|
|
|
net::{Ipv6Addr, SocketAddr},
|
|
|
+ pin::Pin,
|
|
|
sync::Arc,
|
|
|
time::Duration,
|
|
|
};
|
|
@@ -71,24 +72,49 @@ impl Drop for Connection {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+type TcpJoinHandle = Pin<Box<JoinHandle<ah::Result<(TcpStream, SocketAddr)>>>>;
|
|
|
+
|
|
|
+fn spawn_tcp_accept(tcp: Arc<Option<TcpListener>>) -> TcpJoinHandle {
|
|
|
+ Box::pin(task::spawn(async move {
|
|
|
+ if let Some(tcp) = tcp.as_ref() {
|
|
|
+ return Ok(tcp.accept().await?);
|
|
|
+ }
|
|
|
+ sleep_forever().await;
|
|
|
+ unreachable!();
|
|
|
+ }))
|
|
|
+}
|
|
|
+
|
|
|
+type UdpJoinHandle = Pin<Box<JoinHandle<ah::Result<(Arc<MsgUdpDispatcher>, SocketAddr)>>>>;
|
|
|
+
|
|
|
+fn spawn_udp_accept(udp: Arc<Option<Arc<MsgUdpDispatcher>>>) -> UdpJoinHandle {
|
|
|
+ Box::pin(task::spawn(async move {
|
|
|
+ if let Some(udp) = udp.as_ref() {
|
|
|
+ let peer_addr = udp.accept().await?;
|
|
|
+ return Ok((Arc::clone(udp), peer_addr));
|
|
|
+ }
|
|
|
+ sleep_forever().await;
|
|
|
+ unreachable!();
|
|
|
+ }))
|
|
|
+}
|
|
|
+
|
|
|
pub struct Server {
|
|
|
tcp: Arc<Option<TcpListener>>,
|
|
|
+ tcp_join: TcpJoinHandle,
|
|
|
udp: Arc<Option<Arc<MsgUdpDispatcher>>>,
|
|
|
+ udp_join: UdpJoinHandle,
|
|
|
}
|
|
|
|
|
|
impl Server {
|
|
|
pub async fn new(conf: &Config, no_systemd: bool, max_nr_udp_conn: usize) -> ah::Result<Self> {
|
|
|
- let mut this = Self {
|
|
|
- tcp: Arc::new(None),
|
|
|
- udp: Arc::new(None),
|
|
|
- };
|
|
|
+ let mut tcp = None;
|
|
|
+ let mut udp = None;
|
|
|
|
|
|
// Get socket from systemd?
|
|
|
if !no_systemd {
|
|
|
for socket in SystemdSocket::get_all()?.into_iter() {
|
|
|
match socket {
|
|
|
SystemdSocket::Tcp(listener) => {
|
|
|
- if this.tcp.is_some() {
|
|
|
+ if tcp.is_some() {
|
|
|
return Err(err!("Received multiple TCP sockets from systemd."));
|
|
|
}
|
|
|
if !conf.port().tcp {
|
|
@@ -101,13 +127,13 @@ impl Server {
|
|
|
listener
|
|
|
.set_nonblocking(true)
|
|
|
.context("Set socket non-blocking")?;
|
|
|
- this.tcp = Arc::new(Some(
|
|
|
+ tcp = Some(
|
|
|
TcpListener::from_std(listener)
|
|
|
.context("Convert std TcpListener to tokio TcpListener")?,
|
|
|
- ));
|
|
|
+ );
|
|
|
}
|
|
|
SystemdSocket::Udp(socket) => {
|
|
|
- if this.udp.is_some() {
|
|
|
+ if udp.is_some() {
|
|
|
return Err(err!("Received multiple UDP sockets from systemd."));
|
|
|
}
|
|
|
if !conf.port().udp {
|
|
@@ -120,11 +146,11 @@ impl Server {
|
|
|
socket
|
|
|
.set_nonblocking(true)
|
|
|
.context("Set socket non-blocking")?;
|
|
|
- this.udp = Arc::new(Some(Arc::new(MsgUdpDispatcher::new(
|
|
|
+ udp = Some(Arc::new(MsgUdpDispatcher::new(
|
|
|
UdpSocket::from_std(socket)
|
|
|
.context("Convert std UdpSocket to tokio UdpSocket")?,
|
|
|
max_nr_udp_conn,
|
|
|
- ))));
|
|
|
+ )));
|
|
|
}
|
|
|
_ => {
|
|
|
return Err(err!("Received an unusable socket from systemd."));
|
|
@@ -132,68 +158,54 @@ impl Server {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if this.tcp.is_some() || this.udp.is_some() {
|
|
|
+ if tcp.is_some() || udp.is_some() {
|
|
|
systemd_notify_ready()?;
|
|
|
- return Ok(this);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Without systemd.
|
|
|
-
|
|
|
- // TCP bind.
|
|
|
- if conf.port().tcp {
|
|
|
- this.tcp = Arc::new(Some(
|
|
|
- TcpListener::bind((Ipv6Addr::UNSPECIFIED, conf.port().port))
|
|
|
- .await
|
|
|
- .context("Bind")?,
|
|
|
- ));
|
|
|
- }
|
|
|
- // UDP bind.
|
|
|
- if conf.port().udp {
|
|
|
- this.udp = Arc::new(Some(Arc::new(MsgUdpDispatcher::new(
|
|
|
- UdpSocket::bind((Ipv6Addr::UNSPECIFIED, conf.port().port))
|
|
|
- .await
|
|
|
- .context("Bind")?,
|
|
|
- max_nr_udp_conn,
|
|
|
- ))));
|
|
|
+ if tcp.is_none() && udp.is_none() {
|
|
|
+ // TCP bind.
|
|
|
+ if conf.port().tcp {
|
|
|
+ tcp = Some(
|
|
|
+ TcpListener::bind((Ipv6Addr::UNSPECIFIED, conf.port().port))
|
|
|
+ .await
|
|
|
+ .context("Bind")?,
|
|
|
+ );
|
|
|
+ }
|
|
|
+ // UDP bind.
|
|
|
+ if conf.port().udp {
|
|
|
+ udp = Some(Arc::new(MsgUdpDispatcher::new(
|
|
|
+ UdpSocket::bind((Ipv6Addr::UNSPECIFIED, conf.port().port))
|
|
|
+ .await
|
|
|
+ .context("Bind")?,
|
|
|
+ max_nr_udp_conn,
|
|
|
+ )));
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- Ok(this)
|
|
|
+ let tcp = Arc::new(tcp);
|
|
|
+ let tcp_join = spawn_tcp_accept(Arc::clone(&tcp));
|
|
|
+ let udp = Arc::new(udp);
|
|
|
+ let udp_join = spawn_udp_accept(Arc::clone(&udp));
|
|
|
+ Ok(Self {
|
|
|
+ tcp,
|
|
|
+ tcp_join,
|
|
|
+ udp,
|
|
|
+ udp_join,
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
pub async fn accept(&mut self) -> ah::Result<Connection> {
|
|
|
- // Async task for accepting a new TCP connection.
|
|
|
- let join_tcp: JoinHandle<ah::Result<(TcpStream, SocketAddr)>> = task::spawn({
|
|
|
- let tcp = Arc::clone(&self.tcp);
|
|
|
- async move {
|
|
|
- if let Some(tcp) = tcp.as_ref() {
|
|
|
- return Ok(tcp.accept().await?);
|
|
|
- }
|
|
|
- sleep_forever().await;
|
|
|
- unreachable!();
|
|
|
- }
|
|
|
- });
|
|
|
-
|
|
|
- // Async task for accepting a new UDP connection.
|
|
|
- let join_udp: JoinHandle<ah::Result<(Arc<MsgUdpDispatcher>, SocketAddr)>> = task::spawn({
|
|
|
- let udp = Arc::clone(&self.udp);
|
|
|
- async move {
|
|
|
- if let Some(udp) = udp.as_ref() {
|
|
|
- return Ok((Arc::clone(udp), udp.accept().await?));
|
|
|
- }
|
|
|
- sleep_forever().await;
|
|
|
- unreachable!();
|
|
|
- }
|
|
|
- });
|
|
|
-
|
|
|
- // Await any one of the accept tasks.
|
|
|
tokio::select! {
|
|
|
- result = join_tcp => {
|
|
|
+ result = &mut self.tcp_join => {
|
|
|
+ self.tcp_join = spawn_tcp_accept(Arc::clone(&self.tcp));
|
|
|
let (stream, peer_addr) = result??;
|
|
|
let ns = MsgNetSocket::from_tcp(stream);
|
|
|
Ok(Connection::new(ns, peer_addr)?)
|
|
|
}
|
|
|
- result = join_udp => {
|
|
|
+ result = &mut self.udp_join => {
|
|
|
+ self.udp_join = spawn_udp_accept(Arc::clone(&self.udp));
|
|
|
let (udp_disp, peer_addr) = result??;
|
|
|
let ns = MsgNetSocket::from_udp(udp_disp, peer_addr);
|
|
|
Ok(Connection::new(ns, peer_addr)?)
|