mirror of
https://github.com/cubixle/radius-rs.git
synced 2026-04-24 22:54:43 +01:00
Separate server bootstrap sequence between listen() and run()
Initially it uses a channel that given through the `run()` parameter to notify when a server becomes ready, but that doesn't work because it never run the procedure until `await` called. This means if it calls `await`, it blocks the procedure so it cannot consume a channel simultaneously. Thus, it separates bootstrap sequence between `listen()` and `run()`. `listen()`: Start UDP listening. After this function call is finished, the RADIUS server is ready. `run()`: Start a loop to handle the RADIUS requests.
This commit is contained in:
@@ -78,9 +78,9 @@ mod tests {
|
|||||||
use radius::core::code::Code;
|
use radius::core::code::Code;
|
||||||
use radius::core::packet::Packet;
|
use radius::core::packet::Packet;
|
||||||
use radius::core::rfc2865;
|
use radius::core::rfc2865;
|
||||||
use radius::server::Server;
|
|
||||||
|
|
||||||
use crate::test::{LongTimeTakingHandler, MyRequestHandler, MySecretProvider};
|
use crate::test::{LongTimeTakingHandler, MyRequestHandler, MySecretProvider};
|
||||||
|
use radius::server::Server;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_runner() {
|
async fn test_runner() {
|
||||||
@@ -93,18 +93,19 @@ mod tests {
|
|||||||
|
|
||||||
let port = 1812;
|
let port = 1812;
|
||||||
|
|
||||||
|
let mut server = Server::listen(
|
||||||
|
"0.0.0.0",
|
||||||
|
port,
|
||||||
|
1500,
|
||||||
|
true,
|
||||||
|
MyRequestHandler {},
|
||||||
|
MySecretProvider {},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let server_proc = tokio::spawn(async move {
|
let server_proc = tokio::spawn(async move {
|
||||||
Server::run(
|
server.run(receiver).await.unwrap();
|
||||||
"0.0.0.0",
|
|
||||||
port,
|
|
||||||
1500,
|
|
||||||
true,
|
|
||||||
MyRequestHandler {},
|
|
||||||
MySecretProvider {},
|
|
||||||
receiver,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let remote_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap();
|
let remote_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap();
|
||||||
@@ -136,18 +137,19 @@ mod tests {
|
|||||||
|
|
||||||
let port = 1812;
|
let port = 1812;
|
||||||
|
|
||||||
|
let mut server = Server::listen(
|
||||||
|
"0.0.0.0",
|
||||||
|
port,
|
||||||
|
1500,
|
||||||
|
true,
|
||||||
|
LongTimeTakingHandler {},
|
||||||
|
MySecretProvider {},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let server_proc = tokio::spawn(async move {
|
let server_proc = tokio::spawn(async move {
|
||||||
Server::run(
|
server.run(receiver).await.unwrap();
|
||||||
"0.0.0.0",
|
|
||||||
port,
|
|
||||||
1500,
|
|
||||||
true,
|
|
||||||
LongTimeTakingHandler {},
|
|
||||||
MySecretProvider {},
|
|
||||||
receiver,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let remote_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap();
|
let remote_addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap();
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate log;
|
extern crate log;
|
||||||
|
|
||||||
use std::io;
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::{io, process};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
@@ -17,18 +17,30 @@ use radius::server::{RequestHandler, SecretProvider, SecretProviderError, Server
|
|||||||
async fn main() {
|
async fn main() {
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
|
|
||||||
let server_future = Server::run(
|
// start UDP listening
|
||||||
|
let mut server = Server::listen(
|
||||||
"0.0.0.0",
|
"0.0.0.0",
|
||||||
1812,
|
1812,
|
||||||
1500,
|
1500,
|
||||||
true,
|
false,
|
||||||
MyRequestHandler {},
|
MyRequestHandler {},
|
||||||
MySecretProvider {},
|
MySecretProvider {},
|
||||||
signal::ctrl_c(),
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// once it has reached here, a RADIUS server is now ready
|
||||||
|
info!(
|
||||||
|
"serve is now ready: {}",
|
||||||
|
server.get_listen_address().unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = server_future.await;
|
// start the loop to handle the RADIUS requests
|
||||||
|
let result = server.run(signal::ctrl_c()).await;
|
||||||
info!("{:?}", result);
|
info!("{:?}", result);
|
||||||
|
if result.is_err() {
|
||||||
|
process::exit(1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MyRequestHandler {}
|
struct MyRequestHandler {}
|
||||||
|
|||||||
@@ -12,23 +12,81 @@ use tokio::net::UdpSocket;
|
|||||||
use crate::core::packet::Packet;
|
use crate::core::packet::Packet;
|
||||||
use crate::core::request::Request;
|
use crate::core::request::Request;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
/// A basic implementation of the RADIUS server.
|
/// A basic implementation of the RADIUS server.
|
||||||
pub struct Server {}
|
///
|
||||||
|
/// ## Example Usage
|
||||||
|
/// - https://github.com/moznion/radius-rs/blob/HEAD/examples/server.rs
|
||||||
|
pub struct Server<X, E: Debug, T: RequestHandler<X, E>, U: SecretProvider> {
|
||||||
|
skip_authenticity_validation: bool,
|
||||||
|
buf_size: usize,
|
||||||
|
conn_arc: Arc<UdpSocket>,
|
||||||
|
request_handler_arc: Arc<T>,
|
||||||
|
secret_provider_arc: Arc<U>,
|
||||||
|
undergoing_requests_lock_arc: Arc<RwLock<HashSet<RequestKey>>>,
|
||||||
|
_phantom_return_type: PhantomData<X>,
|
||||||
|
_phantom_error_type: PhantomData<E>,
|
||||||
|
}
|
||||||
|
|
||||||
impl Server {
|
impl<X, E: Debug, T: RequestHandler<X, E>, U: SecretProvider> Server<X, E, T, U> {
|
||||||
/// Start listening a UDP socket to process the RAIDUS requests.
|
// NOTE: why it separates between `listen()` and `run()`.
|
||||||
pub async fn run<X, E: Debug, T: RequestHandler<X, E>, U: SecretProvider>(
|
// Initially it uses a channel that given through the `run()` parameter to notify when a server becomes ready,
|
||||||
|
// but that doesn't work because it never run the procedure until `await` called.
|
||||||
|
// This means if it calls `await`, it blocks the procedure so it cannot consume a channel simultaneously.
|
||||||
|
// Thus, it separates bootstrap sequence between `listen()` and `run()`.
|
||||||
|
// `listen()`: Start UDP listening. After this function call is finished, the RADIUS server is ready.
|
||||||
|
// `run()`: Start a loop to handle the RADIUS requests.
|
||||||
|
|
||||||
|
/// Starts UDP listening for the RADIUS server.
|
||||||
|
/// After this function call is finished, the RADIUS server becomes ready to handle the requests;
|
||||||
|
/// then it calls `run()` method for a `Server` instance that returned by this function,
|
||||||
|
/// it starts RADIUS request handling.
|
||||||
|
///
|
||||||
|
/// ## Parameters
|
||||||
|
///
|
||||||
|
/// - `host` - a host to listen (e.g. `0.0.0.0`)
|
||||||
|
/// - `port` - a port number to listen (e.g. `1812`)
|
||||||
|
/// - `buf_size` - a buffer size for receiving the request payload (e.g. `1500`)
|
||||||
|
/// - `skip_authenticity_validation` - a flag to specify whether to skip the authenticity validation or not.
|
||||||
|
/// - `request_handler` - a request handler for the RADIUS requests.
|
||||||
|
/// - `secret_provider` - a provider for shared-secret value.
|
||||||
|
pub async fn listen(
|
||||||
host: &str,
|
host: &str,
|
||||||
port: u16,
|
port: u16,
|
||||||
buf_size: usize,
|
buf_size: usize,
|
||||||
skip_authenticity_validation: bool,
|
skip_authenticity_validation: bool,
|
||||||
request_handler: T,
|
request_handler: T,
|
||||||
secret_provider: U,
|
secret_provider: U,
|
||||||
shutdown_trigger: impl Future,
|
) -> Result<Self, io::Error> {
|
||||||
) -> Result<(), io::Error> {
|
let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new()));
|
||||||
|
let request_handler_arc = Arc::new(request_handler);
|
||||||
|
let secret_provider_arc = Arc::new(secret_provider);
|
||||||
|
|
||||||
|
let address = format!("{}:{}", host, port);
|
||||||
|
let conn = UdpSocket::bind(address).await?;
|
||||||
|
let conn_arc = Arc::new(conn);
|
||||||
|
|
||||||
|
Ok(Server {
|
||||||
|
skip_authenticity_validation,
|
||||||
|
buf_size,
|
||||||
|
conn_arc,
|
||||||
|
request_handler_arc,
|
||||||
|
secret_provider_arc,
|
||||||
|
undergoing_requests_lock_arc,
|
||||||
|
_phantom_return_type: Default::default(),
|
||||||
|
_phantom_error_type: Default::default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts the RADIUS requests handling.
|
||||||
|
///
|
||||||
|
/// ## Parameters
|
||||||
|
///
|
||||||
|
/// - `shutdown_trigger`: an implementation of the `Future` to interrupt to shutdown the RADIUS server (e.g. `signal::ctrl_c()`)
|
||||||
|
pub async fn run(&mut self, shutdown_trigger: impl Future) -> Result<(), io::Error> {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
res = Self::run_loop(host, port, buf_size, skip_authenticity_validation, request_handler, secret_provider) => {
|
res = self.run_loop() => {
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
_ = shutdown_trigger => {
|
_ = shutdown_trigger => {
|
||||||
@@ -38,27 +96,18 @@ impl Server {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_loop<X, E: Debug, T: RequestHandler<X, E>, U: SecretProvider>(
|
/// Returns the listening address.
|
||||||
host: &str,
|
pub fn get_listen_address(&self) -> io::Result<SocketAddr> {
|
||||||
port: u16,
|
self.conn_arc.local_addr()
|
||||||
buf_size: usize,
|
}
|
||||||
skip_authenticity_validation: bool,
|
|
||||||
request_handler: T,
|
|
||||||
secret_provider: U,
|
|
||||||
) -> Result<(), io::Error> {
|
|
||||||
let address = format!("{}:{}", host, port);
|
|
||||||
let conn = UdpSocket::bind(address).await?;
|
|
||||||
|
|
||||||
let conn_arc = Arc::new(conn);
|
async fn run_loop(&self) -> Result<(), io::Error> {
|
||||||
let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new()));
|
let mut buf: Vec<u8> = vec![Default::default(); self.buf_size];
|
||||||
let request_handler_arc = Arc::new(request_handler);
|
|
||||||
let secret_provider_arc = Arc::new(secret_provider);
|
|
||||||
|
|
||||||
let mut buf = vec![Default::default(); buf_size];
|
|
||||||
loop {
|
loop {
|
||||||
let conn = conn_arc.clone();
|
let conn = self.conn_arc.clone();
|
||||||
let request_handler = request_handler_arc.clone();
|
let request_handler = self.request_handler_arc.clone();
|
||||||
let secret_provider = secret_provider_arc.clone();
|
let secret_provider = self.secret_provider_arc.clone();
|
||||||
|
|
||||||
let (size, remote_addr) = conn.recv_from(&mut buf).await?;
|
let (size, remote_addr) = conn.recv_from(&mut buf).await?;
|
||||||
|
|
||||||
@@ -75,7 +124,8 @@ impl Server {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let undergoing_requests_lock = undergoing_requests_lock_arc.clone();
|
let undergoing_requests_lock = self.undergoing_requests_lock_arc.clone();
|
||||||
|
let skip_authenticity_validation = self.skip_authenticity_validation;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
Self::process_request(
|
Self::process_request(
|
||||||
@@ -94,7 +144,7 @@ impl Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn process_request<X, E: Debug, T: RequestHandler<X, E>, U: SecretProvider>(
|
async fn process_request(
|
||||||
conn: Arc<UdpSocket>,
|
conn: Arc<UdpSocket>,
|
||||||
request_data: &[u8],
|
request_data: &[u8],
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
@@ -194,8 +244,7 @@ pub enum SecretProviderError {
|
|||||||
|
|
||||||
/// SecretProvider is a provider for secret value.
|
/// SecretProvider is a provider for secret value.
|
||||||
pub trait SecretProvider: 'static + Sync + Send {
|
pub trait SecretProvider: 'static + Sync + Send {
|
||||||
/// This method has to implement the generator of the secret value to verify the request of
|
/// This method has to implement the generator of the shared-secret value to verify the request.
|
||||||
/// `Accounting-Response`, `Accounting-Response` and `CoA-Request`.
|
|
||||||
fn fetch_secret(&self, remote_addr: SocketAddr) -> Result<Vec<u8>, SecretProviderError>;
|
fn fetch_secret(&self, remote_addr: SocketAddr) -> Result<Vec<u8>, SecretProviderError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user