From cf03846e7385b18dff728a76072e126662e43fb1 Mon Sep 17 00:00:00 2001 From: moznion Date: Mon, 23 Nov 2020 15:25:55 +0900 Subject: [PATCH] Fix the logic for the server shutdown_trigger --- src/lib.rs | 1 - src/request_handler.rs | 2 +- src/secret_provider.rs | 2 +- src/server.rs | 135 +++++++++++++++++---------------- src/server_shutdown_trigger.rs | 31 -------- 5 files changed, 71 insertions(+), 100 deletions(-) delete mode 100644 src/server_shutdown_trigger.rs diff --git a/src/lib.rs b/src/lib.rs index 089a66f..2982981 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,4 +9,3 @@ pub mod server; pub mod secret_provider; pub mod request_handler; pub mod request; -pub mod server_shutdown_trigger; diff --git a/src/request_handler.rs b/src/request_handler.rs index ef56a9d..936af7f 100644 --- a/src/request_handler.rs +++ b/src/request_handler.rs @@ -2,6 +2,6 @@ use tokio::net::UdpSocket; use crate::request::Request; -pub trait RequestHandler: Sync + Send { +pub trait RequestHandler: 'static + Sync + Send { fn handle_radius_request(&self, conn: &UdpSocket, request: &Request); } diff --git a/src/secret_provider.rs b/src/secret_provider.rs index 79f3861..7794c6e 100644 --- a/src/secret_provider.rs +++ b/src/secret_provider.rs @@ -8,6 +8,6 @@ pub enum SecretProviderError { FailedFetching(String) } -pub trait SecretProvider: Sync + Send { +pub trait SecretProvider: 'static + Sync + Send { fn fetch_secret(&self, remote_addr: SocketAddr) -> Result, SecretProviderError>; } diff --git a/src/server.rs b/src/server.rs index e60a57d..05735c3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,6 @@ use std::borrow::Borrow; use std::collections::HashSet; +use std::future::Future; use std::io; use std::net::SocketAddr; use std::sync::{Arc, RwLock}; @@ -10,80 +11,82 @@ use crate::packet::Packet; use crate::request::Request; use crate::request_handler::RequestHandler; use crate::secret_provider::SecretProvider; -use crate::server_shutdown_trigger::ServerShutdownTrigger; -pub struct Server { - address: String, - skip_authenticity_validation: bool, - request_handler_arc: Arc, - secret_provider_arc: Arc, - shutdown_trigger: ServerShutdownTrigger, -} +pub struct Server {} -impl Server { - pub fn new(host: &str, port: u16, skip_authenticity_validation: bool, request_handler: T, secret_provider: U) -> Self { - Self { - address: format!("{}:{}", host, port), - skip_authenticity_validation, - request_handler_arc: Arc::new(request_handler), - secret_provider_arc: Arc::new(secret_provider), - shutdown_trigger: ServerShutdownTrigger::new(), - } - } - - pub async fn run(&'static self) -> Result<(), io::Error> { - let mut buf = Vec::new(); - - let conn_arc = Arc::new(UdpSocket::bind(&self.address).await?); - let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new())); - - loop { - let conn = conn_arc.clone(); - let request_handler = self.request_handler_arc.clone(); - let secret_provider = self.secret_provider_arc.clone(); - - tokio::select! { - received = conn.recv_from(&mut buf) => { - let (size, remote_addr) = received?; - - let request_data = buf[..size].to_vec(); - - let local_addr = match conn.local_addr() { - Ok(addr) => addr, - Err(e) => { - error!("failed to get a local address from from a connection; {}", e); - continue; - } - }; - - let undergoing_requests_lock = undergoing_requests_lock_arc.clone(); - - tokio::spawn(async move { - Self::process_request( - conn, - &request_data, - local_addr, - remote_addr, - undergoing_requests_lock, - request_handler, - secret_provider, - self.skip_authenticity_validation, - ).await; - }); - } - Some(_) = self.shutdown_trigger => { - info!("server is shutting down"); - return Ok(()); - } +impl Server { + pub async fn run( + host: &str, + port: u16, + buf_size: usize, + skip_authenticity_validation: bool, + request_handler: T, + secret_provider: U, + shutdown_trigger: impl Future, + ) -> Result<(), io::Error> { + tokio::select! { + res = Self::run_loop(host, port, buf_size, skip_authenticity_validation, request_handler, secret_provider) => { + res + } + _ = shutdown_trigger => { + info!("server is shutting down"); + Ok(()) } } } - pub fn trigger_shutdown(&mut self) { - self.shutdown_trigger.trigger_shutdown(); + async fn run_loop( + host: &str, + port: u16, + buf_size: usize, + skip_authenticity_validation: bool, + request_handler: T, + secret_provider: U, + ) -> Result<(), io::Error> { + let address = format!("{}:{}", host, port); + let conn = UdpSocket::bind(address).await?; + + let conn_arc = Arc::new(conn); + let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new())); + let request_handler_arc = Arc::new(request_handler); + let secret_provider_arc = Arc::new(secret_provider); + + let mut buf = vec![Default::default(); buf_size]; + loop { + let conn = conn_arc.clone(); + let request_handler = request_handler_arc.clone(); + let secret_provider = secret_provider_arc.clone(); + + let (size, remote_addr) = conn.recv_from(&mut buf).await?; + + let request_data = buf[..size].to_vec(); + + let local_addr = match conn.local_addr() { + Ok(addr) => addr, + Err(e) => { + error!("failed to get a local address from from a connection; {}", e); + continue; + } + }; + + let undergoing_requests_lock = undergoing_requests_lock_arc.clone(); + + tokio::spawn(async move { + Self::process_request( + conn, + &request_data, + local_addr, + remote_addr, + undergoing_requests_lock, + request_handler, + secret_provider, + skip_authenticity_validation, + ).await; + }); + } } - async fn process_request( + async fn process_request( conn: Arc, request_data: &Vec, local_addr: SocketAddr, diff --git a/src/server_shutdown_trigger.rs b/src/server_shutdown_trigger.rs deleted file mode 100644 index 7a876b7..0000000 --- a/src/server_shutdown_trigger.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; - -#[derive(Copy, Clone)] -pub struct ServerShutdownTrigger { - should_shutdown: bool -} - -impl ServerShutdownTrigger { - pub(crate) fn new() -> Self { - ServerShutdownTrigger { - should_shutdown: false, - } - } - - pub(crate) fn trigger_shutdown(&mut self) { - self.should_shutdown = true; - } -} - -impl Future for ServerShutdownTrigger { - type Output = Option<()>; - - fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { - match self.should_shutdown { - true => Poll::from(Some(())), - false => Poll::from(None), - } - } -}