From 88e01fc8286cabfe70f45a35ea0ba6a1552afdf9 Mon Sep 17 00:00:00 2001 From: moznion Date: Sun, 3 Jan 2021 11:19:28 +0900 Subject: [PATCH] Separate server bootstrap sequence between `listen()` and `run()` Initially it uses a channel that given through the `run()` parameter to notify when a server becomes ready, but that doesn't work because it never run the procedure until `await` called. This means if it calls `await`, it blocks the procedure so it cannot consume a channel simultaneously. Thus, it separates bootstrap sequence between `listen()` and `run()`. `listen()`: Start UDP listening. After this function call is finished, the RADIUS server is ready. `run()`: Start a loop to handle the RADIUS requests. --- e2e-test/src/test.rs | 48 +++++++++---------- examples/server.rs | 22 +++++++-- radius/src/server.rs | 107 +++++++++++++++++++++++++++++++------------ 3 files changed, 120 insertions(+), 57 deletions(-) diff --git a/e2e-test/src/test.rs b/e2e-test/src/test.rs index 37b7c69..ee23729 100644 --- a/e2e-test/src/test.rs +++ b/e2e-test/src/test.rs @@ -78,9 +78,9 @@ mod tests { use radius::core::code::Code; use radius::core::packet::Packet; use radius::core::rfc2865; - use radius::server::Server; use crate::test::{LongTimeTakingHandler, MyRequestHandler, MySecretProvider}; + use radius::server::Server; #[tokio::test] async fn test_runner() { @@ -93,18 +93,19 @@ mod tests { let port = 1812; + let mut server = Server::listen( + "0.0.0.0", + port, + 1500, + true, + MyRequestHandler {}, + MySecretProvider {}, + ) + .await + .unwrap(); + let server_proc = tokio::spawn(async move { - Server::run( - "0.0.0.0", - port, - 1500, - true, - MyRequestHandler {}, - MySecretProvider {}, - receiver, - ) - .await - .unwrap(); + server.run(receiver).await.unwrap(); }); let remote_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap(); @@ -136,18 +137,19 @@ mod tests { let port = 1812; + let mut server = Server::listen( + "0.0.0.0", + port, + 1500, + true, + LongTimeTakingHandler {}, + MySecretProvider {}, + ) + .await + .unwrap(); + let server_proc = tokio::spawn(async move { - Server::run( - "0.0.0.0", - port, - 1500, - true, - LongTimeTakingHandler {}, - MySecretProvider {}, - receiver, - ) - .await - .unwrap(); + server.run(receiver).await.unwrap(); }); let remote_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap(); diff --git a/examples/server.rs b/examples/server.rs index 59a3e02..6c4afd3 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,8 +1,8 @@ #[macro_use] extern crate log; -use std::io; use std::net::SocketAddr; +use std::{io, process}; use async_trait::async_trait; use tokio::net::UdpSocket; @@ -17,18 +17,30 @@ use radius::server::{RequestHandler, SecretProvider, SecretProviderError, Server async fn main() { env_logger::init(); - let server_future = Server::run( + // start UDP listening + let mut server = Server::listen( "0.0.0.0", 1812, 1500, - true, + false, MyRequestHandler {}, MySecretProvider {}, - signal::ctrl_c(), + ) + .await + .unwrap(); + + // once it has reached here, a RADIUS server is now ready + info!( + "serve is now ready: {}", + server.get_listen_address().unwrap() ); - let result = server_future.await; + // start the loop to handle the RADIUS requests + let result = server.run(signal::ctrl_c()).await; info!("{:?}", result); + if result.is_err() { + process::exit(1); + } } struct MyRequestHandler {} diff --git a/radius/src/server.rs b/radius/src/server.rs index 5370783..03240f9 100644 --- a/radius/src/server.rs +++ b/radius/src/server.rs @@ -12,23 +12,81 @@ use tokio::net::UdpSocket; use crate::core::packet::Packet; use crate::core::request::Request; use std::fmt::Debug; +use std::marker::PhantomData; /// A basic implementation of the RADIUS server. -pub struct Server {} +/// +/// ## Example Usage +/// - https://github.com/moznion/radius-rs/blob/HEAD/examples/server.rs +pub struct Server, U: SecretProvider> { + skip_authenticity_validation: bool, + buf_size: usize, + conn_arc: Arc, + request_handler_arc: Arc, + secret_provider_arc: Arc, + undergoing_requests_lock_arc: Arc>>, + _phantom_return_type: PhantomData, + _phantom_error_type: PhantomData, +} -impl Server { - /// Start listening a UDP socket to process the RAIDUS requests. - pub async fn run, U: SecretProvider>( +impl, U: SecretProvider> Server { + // NOTE: why it separates between `listen()` and `run()`. + // Initially it uses a channel that given through the `run()` parameter to notify when a server becomes ready, + // but that doesn't work because it never run the procedure until `await` called. + // This means if it calls `await`, it blocks the procedure so it cannot consume a channel simultaneously. + // Thus, it separates bootstrap sequence between `listen()` and `run()`. + // `listen()`: Start UDP listening. After this function call is finished, the RADIUS server is ready. + // `run()`: Start a loop to handle the RADIUS requests. + + /// Starts UDP listening for the RADIUS server. + /// After this function call is finished, the RADIUS server becomes ready to handle the requests; + /// then it calls `run()` method for a `Server` instance that returned by this function, + /// it starts RADIUS request handling. + /// + /// ## Parameters + /// + /// - `host` - a host to listen (e.g. `0.0.0.0`) + /// - `port` - a port number to listen (e.g. `1812`) + /// - `buf_size` - a buffer size for receiving the request payload (e.g. `1500`) + /// - `skip_authenticity_validation` - a flag to specify whether to skip the authenticity validation or not. + /// - `request_handler` - a request handler for the RADIUS requests. + /// - `secret_provider` - a provider for shared-secret value. + pub async fn listen( host: &str, port: u16, buf_size: usize, skip_authenticity_validation: bool, request_handler: T, secret_provider: U, - shutdown_trigger: impl Future, - ) -> Result<(), io::Error> { + ) -> Result { + let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new())); + let request_handler_arc = Arc::new(request_handler); + let secret_provider_arc = Arc::new(secret_provider); + + let address = format!("{}:{}", host, port); + let conn = UdpSocket::bind(address).await?; + let conn_arc = Arc::new(conn); + + Ok(Server { + skip_authenticity_validation, + buf_size, + conn_arc, + request_handler_arc, + secret_provider_arc, + undergoing_requests_lock_arc, + _phantom_return_type: Default::default(), + _phantom_error_type: Default::default(), + }) + } + + /// Starts the RADIUS requests handling. + /// + /// ## Parameters + /// + /// - `shutdown_trigger`: an implementation of the `Future` to interrupt to shutdown the RADIUS server (e.g. `signal::ctrl_c()`) + pub async fn run(&mut self, shutdown_trigger: impl Future) -> Result<(), io::Error> { tokio::select! { - res = Self::run_loop(host, port, buf_size, skip_authenticity_validation, request_handler, secret_provider) => { + res = self.run_loop() => { res } _ = shutdown_trigger => { @@ -38,27 +96,18 @@ impl Server { } } - async fn run_loop, U: SecretProvider>( - host: &str, - port: u16, - buf_size: usize, - skip_authenticity_validation: bool, - request_handler: T, - secret_provider: U, - ) -> Result<(), io::Error> { - let address = format!("{}:{}", host, port); - let conn = UdpSocket::bind(address).await?; + /// Returns the listening address. + pub fn get_listen_address(&self) -> io::Result { + self.conn_arc.local_addr() + } - let conn_arc = Arc::new(conn); - let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new())); - let request_handler_arc = Arc::new(request_handler); - let secret_provider_arc = Arc::new(secret_provider); + async fn run_loop(&self) -> Result<(), io::Error> { + let mut buf: Vec = vec![Default::default(); self.buf_size]; - let mut buf = vec![Default::default(); buf_size]; loop { - let conn = conn_arc.clone(); - let request_handler = request_handler_arc.clone(); - let secret_provider = secret_provider_arc.clone(); + let conn = self.conn_arc.clone(); + let request_handler = self.request_handler_arc.clone(); + let secret_provider = self.secret_provider_arc.clone(); let (size, remote_addr) = conn.recv_from(&mut buf).await?; @@ -75,7 +124,8 @@ impl Server { } }; - let undergoing_requests_lock = undergoing_requests_lock_arc.clone(); + let undergoing_requests_lock = self.undergoing_requests_lock_arc.clone(); + let skip_authenticity_validation = self.skip_authenticity_validation; tokio::spawn(async move { Self::process_request( @@ -94,7 +144,7 @@ impl Server { } #[allow(clippy::too_many_arguments)] - async fn process_request, U: SecretProvider>( + async fn process_request( conn: Arc, request_data: &[u8], local_addr: SocketAddr, @@ -194,8 +244,7 @@ pub enum SecretProviderError { /// SecretProvider is a provider for secret value. pub trait SecretProvider: 'static + Sync + Send { - /// This method has to implement the generator of the secret value to verify the request of - /// `Accounting-Response`, `Accounting-Response` and `CoA-Request`. + /// This method has to implement the generator of the shared-secret value to verify the request. fn fetch_secret(&self, remote_addr: SocketAddr) -> Result, SecretProviderError>; }