From e55679a4d677768dc391c585b3af8852a5f17746 Mon Sep 17 00:00:00 2001 From: moznion Date: Thu, 3 Dec 2020 00:39:12 +0900 Subject: [PATCH] Make the client configurable connection-timeout and socket-timeout --- README.md | 1 - examples/client.rs | 6 ++- radius-client/Cargo.toml | 2 +- radius-client/src/client.rs | 91 +++++++++++++++++++++++++++++-------- 4 files changed, 76 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index b234891..aa2ef93 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,6 @@ Simple example implementations are here: ## Roadmap -- timeout feature on the client - retransmission feature on the client - Support the following RFC dictionaries: - rfc2869 diff --git a/examples/client.rs b/examples/client.rs index 87c1b46..2667612 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -5,6 +5,7 @@ use radius::packet::Packet; use radius::rfc2865; use radius_client::client::Client; use std::net::SocketAddr; +use tokio::time::Duration; #[tokio::main] async fn main() { @@ -14,8 +15,9 @@ async fn main() { let mut req_packet = Packet::new(Code::AccessRequest, &b"secret".to_vec()); rfc2865::add_user_name(&mut req_packet, "admin"); - rfc2865::add_user_password(&mut req_packet, b"p@ssw0rd").unwrap(); // TODO + rfc2865::add_user_password(&mut req_packet, b"p@ssw0rd").unwrap(); - let res = Client::send_packet(&remote_addr, &req_packet).await; + let client = Client::new(Some(Duration::from_secs(3)), Some(Duration::from_secs(5))); + let res = client.send_packet(&remote_addr, &req_packet).await; info!("response: {:?}", res); } diff --git a/radius-client/Cargo.toml b/radius-client/Cargo.toml index 0e9d5fb..e8af504 100644 --- a/radius-client/Cargo.toml +++ b/radius-client/Cargo.toml @@ -9,5 +9,5 @@ keywords = ["radius", "client"] [dependencies] radius = { path = "../radius" } -tokio = { version = "0.3.4", features = ["net"] } +tokio = { version = "0.3.4", features = ["net", "time"] } thiserror = "1.0" diff --git a/radius-client/src/client.rs b/radius-client/src/client.rs index 71643ff..050025b 100644 --- a/radius-client/src/client.rs +++ b/radius-client/src/client.rs @@ -1,13 +1,16 @@ use std::net::SocketAddr; +use std::time::Duration; use thiserror::Error; use tokio::net::UdpSocket; +use tokio::time::timeout; + +use radius::packet::Packet; use crate::client::ClientError::{ FailedConnection, FailedParsingUDPResponse, FailedRadiusPacketEncoding, FailedReceivingResponse, FailedSendingPacket, FailedUdpSocketBinding, }; -use radius::packet::Packet; #[derive(Error, Debug)] pub enum ClientError { @@ -23,14 +26,29 @@ pub enum ClientError { FailedReceivingResponse(String, String), #[error("failed to parse a UDP response into a RADIUS packet => `{0}`")] FailedParsingUDPResponse(String), + #[error("connection timeout")] + ConnectionTimeoutError(), + #[error("socket timeout")] + SocketTimeoutError(), } -pub struct Client {} +pub struct Client { + connection_timeout: Option, + socket_timeout: Option, +} impl Client { const MAX_DATAGRAM_SIZE: usize = 65507; + pub fn new(connection_timeout: Option, socket_timeout: Option) -> Self { + Client { + connection_timeout, + socket_timeout, + } + } + pub async fn send_packet( + &self, remote_addr: &SocketAddr, request_packet: &Packet, ) -> Result { @@ -48,35 +66,68 @@ impl Client { Ok(conn) => conn, Err(e) => return Err(FailedUdpSocketBinding(e.to_string())), }; - match conn.connect(remote_addr).await { - Ok(_) => {} - Err(e) => return Err(FailedConnection(remote_addr.to_string(), e.to_string())), - }; + + match self.connection_timeout { + Some(connection_timeout) => { + match timeout(connection_timeout, self.connect(&conn, remote_addr)).await { + Ok(conn_establish_res) => conn_establish_res, + Err(_) => Err(ClientError::ConnectionTimeoutError()), + } + } + None => self.connect(&conn, remote_addr).await, + }?; let request_data = match request_packet.encode() { Ok(encoded) => encoded, Err(e) => return Err(FailedRadiusPacketEncoding(format!("{:?}", e))), }; - match conn.send(request_data.as_slice()).await { + let response = match self.socket_timeout { + Some(socket_timeout) => { + match timeout( + socket_timeout, + self.request(&conn, &request_data, remote_addr), + ) + .await + { + Ok(response) => response, + Err(_) => Err(ClientError::SocketTimeoutError()), + } + } + None => self.request(&conn, &request_data, remote_addr).await, + }?; + + match Packet::decode(&response.to_vec(), request_packet.get_secret()) { + Ok(response_packet) => Ok(response_packet), + Err(e) => Err(FailedParsingUDPResponse(format!("{:?}", e))), + } + } + + async fn connect(&self, conn: &UdpSocket, remote_addr: &SocketAddr) -> Result<(), ClientError> { + match conn.connect(remote_addr).await { + Ok(_) => Ok(()), + Err(e) => Err(FailedConnection(remote_addr.to_string(), e.to_string())), + } + } + + async fn request( + &self, + conn: &UdpSocket, + request_data: &[u8], + remote_addr: &SocketAddr, + ) -> Result, ClientError> { + match conn.send(request_data).await { Ok(_) => {} Err(e) => return Err(FailedSendingPacket(remote_addr.to_string(), e.to_string())), }; let mut buf = vec![0; Self::MAX_DATAGRAM_SIZE]; - let len = match conn.recv(&mut buf).await { - Ok(len) => len, - Err(e) => { - return Err(FailedReceivingResponse( - remote_addr.to_string(), - e.to_string(), - )) - } - }; - - match Packet::decode(&buf[..len].to_vec(), request_packet.get_secret()) { - Ok(response_packet) => Ok(response_packet), - Err(e) => Err(FailedParsingUDPResponse(format!("{:?}", e))), + match conn.recv(&mut buf).await { + Ok(len) => Ok(buf[..len].to_vec()), + Err(e) => Err(FailedReceivingResponse( + remote_addr.to_string(), + e.to_string(), + )), } } }