mirror of
https://github.com/cubixle/radius-rs.git
synced 2026-04-24 21:24:43 +01:00
Fix the logic for the server shutdown_trigger
This commit is contained in:
@@ -9,4 +9,3 @@ pub mod server;
|
||||
pub mod secret_provider;
|
||||
pub mod request_handler;
|
||||
pub mod request;
|
||||
pub mod server_shutdown_trigger;
|
||||
|
||||
@@ -2,6 +2,6 @@ use tokio::net::UdpSocket;
|
||||
|
||||
use crate::request::Request;
|
||||
|
||||
pub trait RequestHandler: Sync + Send {
|
||||
pub trait RequestHandler: 'static + Sync + Send {
|
||||
fn handle_radius_request(&self, conn: &UdpSocket, request: &Request);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,6 @@ pub enum SecretProviderError {
|
||||
FailedFetching(String)
|
||||
}
|
||||
|
||||
pub trait SecretProvider: Sync + Send {
|
||||
pub trait SecretProvider: 'static + Sync + Send {
|
||||
fn fetch_secret(&self, remote_addr: SocketAddr) -> Result<Vec<u8>, SecretProviderError>;
|
||||
}
|
||||
|
||||
135
src/server.rs
135
src/server.rs
@@ -1,5 +1,6 @@
|
||||
use std::borrow::Borrow;
|
||||
use std::collections::HashSet;
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{Arc, RwLock};
|
||||
@@ -10,80 +11,82 @@ use crate::packet::Packet;
|
||||
use crate::request::Request;
|
||||
use crate::request_handler::RequestHandler;
|
||||
use crate::secret_provider::SecretProvider;
|
||||
use crate::server_shutdown_trigger::ServerShutdownTrigger;
|
||||
|
||||
pub struct Server<T: RequestHandler, U: SecretProvider> {
|
||||
address: String,
|
||||
skip_authenticity_validation: bool,
|
||||
request_handler_arc: Arc<T>,
|
||||
secret_provider_arc: Arc<U>,
|
||||
shutdown_trigger: ServerShutdownTrigger,
|
||||
}
|
||||
pub struct Server {}
|
||||
|
||||
impl<T: RequestHandler, U: SecretProvider> Server<T, U> {
|
||||
pub fn new(host: &str, port: u16, skip_authenticity_validation: bool, request_handler: T, secret_provider: U) -> Self {
|
||||
Self {
|
||||
address: format!("{}:{}", host, port),
|
||||
skip_authenticity_validation,
|
||||
request_handler_arc: Arc::new(request_handler),
|
||||
secret_provider_arc: Arc::new(secret_provider),
|
||||
shutdown_trigger: ServerShutdownTrigger::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(&'static self) -> Result<(), io::Error> {
|
||||
let mut buf = Vec::new();
|
||||
|
||||
let conn_arc = Arc::new(UdpSocket::bind(&self.address).await?);
|
||||
let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new()));
|
||||
|
||||
loop {
|
||||
let conn = conn_arc.clone();
|
||||
let request_handler = self.request_handler_arc.clone();
|
||||
let secret_provider = self.secret_provider_arc.clone();
|
||||
|
||||
tokio::select! {
|
||||
received = conn.recv_from(&mut buf) => {
|
||||
let (size, remote_addr) = received?;
|
||||
|
||||
let request_data = buf[..size].to_vec();
|
||||
|
||||
let local_addr = match conn.local_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(e) => {
|
||||
error!("failed to get a local address from from a connection; {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let undergoing_requests_lock = undergoing_requests_lock_arc.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
Self::process_request(
|
||||
conn,
|
||||
&request_data,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
undergoing_requests_lock,
|
||||
request_handler,
|
||||
secret_provider,
|
||||
self.skip_authenticity_validation,
|
||||
).await;
|
||||
});
|
||||
}
|
||||
Some(_) = self.shutdown_trigger => {
|
||||
info!("server is shutting down");
|
||||
return Ok(());
|
||||
}
|
||||
impl Server {
|
||||
pub async fn run<T: RequestHandler, U: SecretProvider>(
|
||||
host: &str,
|
||||
port: u16,
|
||||
buf_size: usize,
|
||||
skip_authenticity_validation: bool,
|
||||
request_handler: T,
|
||||
secret_provider: U,
|
||||
shutdown_trigger: impl Future,
|
||||
) -> Result<(), io::Error> {
|
||||
tokio::select! {
|
||||
res = Self::run_loop(host, port, buf_size, skip_authenticity_validation, request_handler, secret_provider) => {
|
||||
res
|
||||
}
|
||||
_ = shutdown_trigger => {
|
||||
info!("server is shutting down");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn trigger_shutdown(&mut self) {
|
||||
self.shutdown_trigger.trigger_shutdown();
|
||||
async fn run_loop<T: RequestHandler, U: SecretProvider>(
|
||||
host: &str,
|
||||
port: u16,
|
||||
buf_size: usize,
|
||||
skip_authenticity_validation: bool,
|
||||
request_handler: T,
|
||||
secret_provider: U,
|
||||
) -> Result<(), io::Error> {
|
||||
let address = format!("{}:{}", host, port);
|
||||
let conn = UdpSocket::bind(address).await?;
|
||||
|
||||
let conn_arc = Arc::new(conn);
|
||||
let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new()));
|
||||
let request_handler_arc = Arc::new(request_handler);
|
||||
let secret_provider_arc = Arc::new(secret_provider);
|
||||
|
||||
let mut buf = vec![Default::default(); buf_size];
|
||||
loop {
|
||||
let conn = conn_arc.clone();
|
||||
let request_handler = request_handler_arc.clone();
|
||||
let secret_provider = secret_provider_arc.clone();
|
||||
|
||||
let (size, remote_addr) = conn.recv_from(&mut buf).await?;
|
||||
|
||||
let request_data = buf[..size].to_vec();
|
||||
|
||||
let local_addr = match conn.local_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(e) => {
|
||||
error!("failed to get a local address from from a connection; {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let undergoing_requests_lock = undergoing_requests_lock_arc.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
Self::process_request(
|
||||
conn,
|
||||
&request_data,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
undergoing_requests_lock,
|
||||
request_handler,
|
||||
secret_provider,
|
||||
skip_authenticity_validation,
|
||||
).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_request(
|
||||
async fn process_request<T: RequestHandler, U: SecretProvider>(
|
||||
conn: Arc<UdpSocket>,
|
||||
request_data: &Vec<u8>,
|
||||
local_addr: SocketAddr,
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ServerShutdownTrigger {
|
||||
should_shutdown: bool
|
||||
}
|
||||
|
||||
impl ServerShutdownTrigger {
|
||||
pub(crate) fn new() -> Self {
|
||||
ServerShutdownTrigger {
|
||||
should_shutdown: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn trigger_shutdown(&mut self) {
|
||||
self.should_shutdown = true;
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for ServerShutdownTrigger {
|
||||
type Output = Option<()>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.should_shutdown {
|
||||
true => Poll::from(Some(())),
|
||||
false => Poll::from(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user