Separate server bootstrap sequence between listen() and run()

Initially it uses a channel that given through the `run()` parameter to notify when a server becomes ready,
but that doesn't work because it never run the procedure until `await` called.
This means if it calls `await`, it blocks the procedure so it cannot consume a channel simultaneously.
Thus, it separates bootstrap sequence between `listen()` and `run()`.
`listen()`: Start UDP listening. After this function call is finished, the RADIUS server is ready.
`run()`: Start a loop to handle the RADIUS requests.
This commit is contained in:
moznion
2021-01-03 11:19:28 +09:00
parent 5c74cf92c7
commit 88e01fc828
3 changed files with 120 additions and 57 deletions

View File

@@ -78,9 +78,9 @@ mod tests {
use radius::core::code::Code; use radius::core::code::Code;
use radius::core::packet::Packet; use radius::core::packet::Packet;
use radius::core::rfc2865; use radius::core::rfc2865;
use radius::server::Server;
use crate::test::{LongTimeTakingHandler, MyRequestHandler, MySecretProvider}; use crate::test::{LongTimeTakingHandler, MyRequestHandler, MySecretProvider};
use radius::server::Server;
#[tokio::test] #[tokio::test]
async fn test_runner() { async fn test_runner() {
@@ -93,18 +93,19 @@ mod tests {
let port = 1812; let port = 1812;
let mut server = Server::listen(
"0.0.0.0",
port,
1500,
true,
MyRequestHandler {},
MySecretProvider {},
)
.await
.unwrap();
let server_proc = tokio::spawn(async move { let server_proc = tokio::spawn(async move {
Server::run( server.run(receiver).await.unwrap();
"0.0.0.0",
port,
1500,
true,
MyRequestHandler {},
MySecretProvider {},
receiver,
)
.await
.unwrap();
}); });
let remote_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap(); let remote_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap();
@@ -136,18 +137,19 @@ mod tests {
let port = 1812; let port = 1812;
let mut server = Server::listen(
"0.0.0.0",
port,
1500,
true,
LongTimeTakingHandler {},
MySecretProvider {},
)
.await
.unwrap();
let server_proc = tokio::spawn(async move { let server_proc = tokio::spawn(async move {
Server::run( server.run(receiver).await.unwrap();
"0.0.0.0",
port,
1500,
true,
LongTimeTakingHandler {},
MySecretProvider {},
receiver,
)
.await
.unwrap();
}); });
let remote_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap(); let remote_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap();

View File

@@ -1,8 +1,8 @@
#[macro_use] #[macro_use]
extern crate log; extern crate log;
use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::{io, process};
use async_trait::async_trait; use async_trait::async_trait;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
@@ -17,18 +17,30 @@ use radius::server::{RequestHandler, SecretProvider, SecretProviderError, Server
async fn main() { async fn main() {
env_logger::init(); env_logger::init();
let server_future = Server::run( // start UDP listening
let mut server = Server::listen(
"0.0.0.0", "0.0.0.0",
1812, 1812,
1500, 1500,
true, false,
MyRequestHandler {}, MyRequestHandler {},
MySecretProvider {}, MySecretProvider {},
signal::ctrl_c(), )
.await
.unwrap();
// once it has reached here, a RADIUS server is now ready
info!(
"serve is now ready: {}",
server.get_listen_address().unwrap()
); );
let result = server_future.await; // start the loop to handle the RADIUS requests
let result = server.run(signal::ctrl_c()).await;
info!("{:?}", result); info!("{:?}", result);
if result.is_err() {
process::exit(1);
}
} }
struct MyRequestHandler {} struct MyRequestHandler {}

View File

@@ -12,23 +12,81 @@ use tokio::net::UdpSocket;
use crate::core::packet::Packet; use crate::core::packet::Packet;
use crate::core::request::Request; use crate::core::request::Request;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData;
/// A basic implementation of the RADIUS server. /// A basic implementation of the RADIUS server.
pub struct Server {} ///
/// ## Example Usage
/// - https://github.com/moznion/radius-rs/blob/HEAD/examples/server.rs
pub struct Server<X, E: Debug, T: RequestHandler<X, E>, U: SecretProvider> {
skip_authenticity_validation: bool,
buf_size: usize,
conn_arc: Arc<UdpSocket>,
request_handler_arc: Arc<T>,
secret_provider_arc: Arc<U>,
undergoing_requests_lock_arc: Arc<RwLock<HashSet<RequestKey>>>,
_phantom_return_type: PhantomData<X>,
_phantom_error_type: PhantomData<E>,
}
impl Server { impl<X, E: Debug, T: RequestHandler<X, E>, U: SecretProvider> Server<X, E, T, U> {
/// Start listening a UDP socket to process the RAIDUS requests. // NOTE: why it separates between `listen()` and `run()`.
pub async fn run<X, E: Debug, T: RequestHandler<X, E>, U: SecretProvider>( // Initially it uses a channel that given through the `run()` parameter to notify when a server becomes ready,
// but that doesn't work because it never run the procedure until `await` called.
// This means if it calls `await`, it blocks the procedure so it cannot consume a channel simultaneously.
// Thus, it separates bootstrap sequence between `listen()` and `run()`.
// `listen()`: Start UDP listening. After this function call is finished, the RADIUS server is ready.
// `run()`: Start a loop to handle the RADIUS requests.
/// Starts UDP listening for the RADIUS server.
/// After this function call is finished, the RADIUS server becomes ready to handle the requests;
/// then it calls `run()` method for a `Server` instance that returned by this function,
/// it starts RADIUS request handling.
///
/// ## Parameters
///
/// - `host` - a host to listen (e.g. `0.0.0.0`)
/// - `port` - a port number to listen (e.g. `1812`)
/// - `buf_size` - a buffer size for receiving the request payload (e.g. `1500`)
/// - `skip_authenticity_validation` - a flag to specify whether to skip the authenticity validation or not.
/// - `request_handler` - a request handler for the RADIUS requests.
/// - `secret_provider` - a provider for shared-secret value.
pub async fn listen(
host: &str, host: &str,
port: u16, port: u16,
buf_size: usize, buf_size: usize,
skip_authenticity_validation: bool, skip_authenticity_validation: bool,
request_handler: T, request_handler: T,
secret_provider: U, secret_provider: U,
shutdown_trigger: impl Future, ) -> Result<Self, io::Error> {
) -> Result<(), io::Error> { let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new()));
let request_handler_arc = Arc::new(request_handler);
let secret_provider_arc = Arc::new(secret_provider);
let address = format!("{}:{}", host, port);
let conn = UdpSocket::bind(address).await?;
let conn_arc = Arc::new(conn);
Ok(Server {
skip_authenticity_validation,
buf_size,
conn_arc,
request_handler_arc,
secret_provider_arc,
undergoing_requests_lock_arc,
_phantom_return_type: Default::default(),
_phantom_error_type: Default::default(),
})
}
/// Starts the RADIUS requests handling.
///
/// ## Parameters
///
/// - `shutdown_trigger`: an implementation of the `Future` to interrupt to shutdown the RADIUS server (e.g. `signal::ctrl_c()`)
pub async fn run(&mut self, shutdown_trigger: impl Future) -> Result<(), io::Error> {
tokio::select! { tokio::select! {
res = Self::run_loop(host, port, buf_size, skip_authenticity_validation, request_handler, secret_provider) => { res = self.run_loop() => {
res res
} }
_ = shutdown_trigger => { _ = shutdown_trigger => {
@@ -38,27 +96,18 @@ impl Server {
} }
} }
async fn run_loop<X, E: Debug, T: RequestHandler<X, E>, U: SecretProvider>( /// Returns the listening address.
host: &str, pub fn get_listen_address(&self) -> io::Result<SocketAddr> {
port: u16, self.conn_arc.local_addr()
buf_size: usize, }
skip_authenticity_validation: bool,
request_handler: T,
secret_provider: U,
) -> Result<(), io::Error> {
let address = format!("{}:{}", host, port);
let conn = UdpSocket::bind(address).await?;
let conn_arc = Arc::new(conn); async fn run_loop(&self) -> Result<(), io::Error> {
let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new())); let mut buf: Vec<u8> = vec![Default::default(); self.buf_size];
let request_handler_arc = Arc::new(request_handler);
let secret_provider_arc = Arc::new(secret_provider);
let mut buf = vec![Default::default(); buf_size];
loop { loop {
let conn = conn_arc.clone(); let conn = self.conn_arc.clone();
let request_handler = request_handler_arc.clone(); let request_handler = self.request_handler_arc.clone();
let secret_provider = secret_provider_arc.clone(); let secret_provider = self.secret_provider_arc.clone();
let (size, remote_addr) = conn.recv_from(&mut buf).await?; let (size, remote_addr) = conn.recv_from(&mut buf).await?;
@@ -75,7 +124,8 @@ impl Server {
} }
}; };
let undergoing_requests_lock = undergoing_requests_lock_arc.clone(); let undergoing_requests_lock = self.undergoing_requests_lock_arc.clone();
let skip_authenticity_validation = self.skip_authenticity_validation;
tokio::spawn(async move { tokio::spawn(async move {
Self::process_request( Self::process_request(
@@ -94,7 +144,7 @@ impl Server {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
async fn process_request<X, E: Debug, T: RequestHandler<X, E>, U: SecretProvider>( async fn process_request(
conn: Arc<UdpSocket>, conn: Arc<UdpSocket>,
request_data: &[u8], request_data: &[u8],
local_addr: SocketAddr, local_addr: SocketAddr,
@@ -194,8 +244,7 @@ pub enum SecretProviderError {
/// SecretProvider is a provider for secret value. /// SecretProvider is a provider for secret value.
pub trait SecretProvider: 'static + Sync + Send { pub trait SecretProvider: 'static + Sync + Send {
/// This method has to implement the generator of the secret value to verify the request of /// This method has to implement the generator of the shared-secret value to verify the request.
/// `Accounting-Response`, `Accounting-Response` and `CoA-Request`.
fn fetch_secret(&self, remote_addr: SocketAddr) -> Result<Vec<u8>, SecretProviderError>; fn fetch_secret(&self, remote_addr: SocketAddr) -> Result<Vec<u8>, SecretProviderError>;
} }