Fix the logic for the server shutdown_trigger

This commit is contained in:
moznion
2020-11-23 15:25:55 +09:00
parent f7485efff3
commit cf03846e73
5 changed files with 71 additions and 100 deletions
-1
View File
@@ -9,4 +9,3 @@ pub mod server;
pub mod secret_provider; pub mod secret_provider;
pub mod request_handler; pub mod request_handler;
pub mod request; pub mod request;
pub mod server_shutdown_trigger;
+1 -1
View File
@@ -2,6 +2,6 @@ use tokio::net::UdpSocket;
use crate::request::Request; use crate::request::Request;
pub trait RequestHandler: Sync + Send { pub trait RequestHandler: 'static + Sync + Send {
fn handle_radius_request(&self, conn: &UdpSocket, request: &Request); fn handle_radius_request(&self, conn: &UdpSocket, request: &Request);
} }
+1 -1
View File
@@ -8,6 +8,6 @@ pub enum SecretProviderError {
FailedFetching(String) FailedFetching(String)
} }
pub trait SecretProvider: Sync + Send { pub trait SecretProvider: 'static + Sync + Send {
fn fetch_secret(&self, remote_addr: SocketAddr) -> Result<Vec<u8>, SecretProviderError>; fn fetch_secret(&self, remote_addr: SocketAddr) -> Result<Vec<u8>, SecretProviderError>;
} }
+38 -35
View File
@@ -1,5 +1,6 @@
use std::borrow::Borrow; use std::borrow::Borrow;
use std::collections::HashSet; use std::collections::HashSet;
use std::future::Future;
use std::io; use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
@@ -10,41 +11,53 @@ use crate::packet::Packet;
use crate::request::Request; use crate::request::Request;
use crate::request_handler::RequestHandler; use crate::request_handler::RequestHandler;
use crate::secret_provider::SecretProvider; use crate::secret_provider::SecretProvider;
use crate::server_shutdown_trigger::ServerShutdownTrigger;
pub struct Server<T: RequestHandler, U: SecretProvider> { pub struct Server {}
address: String,
impl Server {
pub async fn run<T: RequestHandler, U: SecretProvider>(
host: &str,
port: u16,
buf_size: usize,
skip_authenticity_validation: bool, skip_authenticity_validation: bool,
request_handler_arc: Arc<T>, request_handler: T,
secret_provider_arc: Arc<U>, secret_provider: U,
shutdown_trigger: ServerShutdownTrigger, shutdown_trigger: impl Future,
) -> Result<(), io::Error> {
tokio::select! {
res = Self::run_loop(host, port, buf_size, skip_authenticity_validation, request_handler, secret_provider) => {
res
}
_ = shutdown_trigger => {
info!("server is shutting down");
Ok(())
} }
impl<T: RequestHandler, U: SecretProvider> Server<T, U> {
pub fn new(host: &str, port: u16, skip_authenticity_validation: bool, request_handler: T, secret_provider: U) -> Self {
Self {
address: format!("{}:{}", host, port),
skip_authenticity_validation,
request_handler_arc: Arc::new(request_handler),
secret_provider_arc: Arc::new(secret_provider),
shutdown_trigger: ServerShutdownTrigger::new(),
} }
} }
pub async fn run(&'static self) -> Result<(), io::Error> { async fn run_loop<T: RequestHandler, U: SecretProvider>(
let mut buf = Vec::new(); host: &str,
port: u16,
buf_size: usize,
skip_authenticity_validation: bool,
request_handler: T,
secret_provider: U,
) -> Result<(), io::Error> {
let address = format!("{}:{}", host, port);
let conn = UdpSocket::bind(address).await?;
let conn_arc = Arc::new(UdpSocket::bind(&self.address).await?); let conn_arc = Arc::new(conn);
let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new())); let undergoing_requests_lock_arc = Arc::new(RwLock::new(HashSet::new()));
let request_handler_arc = Arc::new(request_handler);
let secret_provider_arc = Arc::new(secret_provider);
let mut buf = vec![Default::default(); buf_size];
loop { loop {
let conn = conn_arc.clone(); let conn = conn_arc.clone();
let request_handler = self.request_handler_arc.clone(); let request_handler = request_handler_arc.clone();
let secret_provider = self.secret_provider_arc.clone(); let secret_provider = secret_provider_arc.clone();
tokio::select! { let (size, remote_addr) = conn.recv_from(&mut buf).await?;
received = conn.recv_from(&mut buf) => {
let (size, remote_addr) = received?;
let request_data = buf[..size].to_vec(); let request_data = buf[..size].to_vec();
@@ -67,23 +80,13 @@ impl<T: RequestHandler, U: SecretProvider> Server<T, U> {
undergoing_requests_lock, undergoing_requests_lock,
request_handler, request_handler,
secret_provider, secret_provider,
self.skip_authenticity_validation, skip_authenticity_validation,
).await; ).await;
}); });
} }
Some(_) = self.shutdown_trigger => {
info!("server is shutting down");
return Ok(());
}
}
}
} }
pub fn trigger_shutdown(&mut self) { async fn process_request<T: RequestHandler, U: SecretProvider>(
self.shutdown_trigger.trigger_shutdown();
}
async fn process_request(
conn: Arc<UdpSocket>, conn: Arc<UdpSocket>,
request_data: &Vec<u8>, request_data: &Vec<u8>,
local_addr: SocketAddr, local_addr: SocketAddr,
-31
View File
@@ -1,31 +0,0 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
#[derive(Copy, Clone)]
pub struct ServerShutdownTrigger {
should_shutdown: bool
}
impl ServerShutdownTrigger {
pub(crate) fn new() -> Self {
ServerShutdownTrigger {
should_shutdown: false,
}
}
pub(crate) fn trigger_shutdown(&mut self) {
self.should_shutdown = true;
}
}
impl Future for ServerShutdownTrigger {
type Output = Option<()>;
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.should_shutdown {
true => Poll::from(Some(())),
false => Poll::from(None),
}
}
}