diff --git a/README.md b/README.md index 777c235..8dde807 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,13 @@ To enable a client with IP `X.X.X.X` to receive files from a remote NBD disk, cr --- +If no directory named `/x.x.x.x` or corresponding NBD config `/x.x.x.x.nbd>` is found, the system attempts to read the requested file from `/default>`. This allows all peers to be served with a single file or enables RTFTP to function as a standard TFTP server. + + +Additionally, RTFTP supports proactive setup of NBD connections upon the appearance of an NBD configuration file by utilizing [**inotify**](https://man7.org/linux/man-pages/man7/inotify.7.html) subsystem. With this approach, the remote filesystem is already up and running before the first TFTP request arrives. + +--- + ## Example TFTP root directory layout: @@ -46,11 +53,12 @@ TFTP root directory layout: ``` tftp_root/ ├── 192.168.10.10/ -│ ├── efi/ -│ │ └── grubnetaa64.efi.signed │ └── grub/ │ └── grub.cfg -└── 192.168.10.10.nbd +├── 192.168.10.10.nbd +└── default/ + └── efi/ + └── grubnetaa64.efi.signed ``` Contents of `192.168.10.10.nbd`: @@ -74,15 +82,17 @@ Contents of `192.168.10.10.nbd`: In this example: -- The client with IP `192.168.10.10` will receive `efi/grubnetaa64.efi.signed` and `grub/grub.cfg` from the **local filesystem**. -- Any other files will be retrieved from the **remote NBD disk** at `nbd://10.10.10.10:25000/server_root` from inside `/boot` directory from the **first** partition. - +- The client with IP `192.168.10.10` will receive `efi/grubnetaa64.efi.signed` from the **local filesystem** from the `tftp_root/default` directory +- The client with IP `192.168.10.10` will `grub/grub.cfg` from the **local filesystem** from a specific `tftp_root/192.168.10.10` directory. +- Any other files will be retrieved by the client with IP `192.168.10.10` from the **remote NBD disk** at `nbd://10.10.10.10:25000/server_root` from inside `/boot` directory from the **first** partition. +- Clients with any other IPs will be able to download only `efi/grubnetaa64.efi.signed` from the `tftp_root/192.168.10.10` directory. --- ### Notes: - Only Read Request (RRQ) is supported. - If a file exists in both the local directory and the NBD-based filesystem, the **local file takes precedence**. +- If a file exists in both the `default` directory and a client directory, the latter is downloaded. - Initial setup of the virtual NBD filesystem takes **1.5 to 3 seconds**, so the first request usually need to be retried automatically by the client. - The NBD disk is either: - Connected proactively when config is created to avoid the first read request delay. @@ -92,6 +102,7 @@ In this example: - timeout - blksize - tsize + - windowsize - The daemon is intended to run without root privileges. To allow RTFTP to bind to UDP port 69, one of following workarounds may be applied: - Add **CAP_NET_BIND_SERVICE** capability to RTFTP: `setcap 'cap_net_bind_service=+ep' /path/to/rtftp` - Start RTFTP via `authbind` with port 69 allowed for the RTFTP user: `touch /etc/authbind/byport/69 && chown : /etc/authbind/byport/69` diff --git a/src/cursor/mod.rs b/src/cursor/mod.rs index efc2805..0370087 100644 --- a/src/cursor/mod.rs +++ b/src/cursor/mod.rs @@ -70,7 +70,8 @@ impl<'a> WriteCursor<'a> { if end_index > self.buffer.len() { return Err(BufferError::new("Too little data left to write u16")); } - self.buffer[self.offset..end_index].copy_from_slice(&value.to_be_bytes()); + self.buffer[self.offset] = ((value & 0xFF00) >> 8) as u8; + self.buffer[self.offset + 1] = (value & 0xFF) as u8; self.offset = end_index; Ok(self.offset) } diff --git a/src/datagram_stream.rs b/src/datagram_stream.rs new file mode 100644 index 0000000..90dbe2f --- /dev/null +++ b/src/datagram_stream.rs @@ -0,0 +1,75 @@ +use std::fmt; +use std::fmt::{Debug, Display, Formatter}; +use std::io::ErrorKind; +use std::net::SocketAddr; +use tokio::net::UdpSocket; + +pub(super) struct DatagramStream { + local_socket: UdpSocket, + peer_address: SocketAddr, + display: String, +} + +impl DatagramStream { + pub(super) fn new(local_socket: UdpSocket, peer_address: SocketAddr) -> Self { + let local_address = local_socket.local_addr().unwrap(); + let local_ip = local_address.ip().to_string(); + let local_port = local_address.port().to_string(); + let remote_ip = peer_address.ip().to_string(); + let remote_port = peer_address.port().to_string(); + let display = format!("{local_ip}:{local_port} <=> {remote_ip}:{remote_port}"); + Self { + local_socket, + peer_address, + display, + } + } + + pub(super) fn remote_port(&self) -> u16 { + self.peer_address.port() + } + + pub(super) async fn send(&self, buffer: &[u8]) -> std::io::Result<()> { + match self.local_socket.send_to(buffer, self.peer_address).await { + Ok(sent) => { + if sent != buffer.len() { + Err(ErrorKind::ConnectionReset.into()) + } else { + Ok(()) + } + } + Err(error) => Err(error), + } + } + + pub(super) async fn recv(&self, buffer: &mut [u8], min_size: usize) -> std::io::Result { + loop { + match self.local_socket.recv_from(buffer).await { + Ok((recv_size, remote_address)) => { + if remote_address != self.peer_address { + eprintln!( + "{self}: Ignore datagram {recv_size} long from alien {remote_address}" + ); + } else if recv_size < min_size { + eprintln!("{self}: Ignore runt datagram {recv_size} long"); + } else { + return Ok(recv_size); + } + } + Err(error) => return Err(error), + } + } + } +} + +impl Debug for DatagramStream { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "<{}>", self.display) + } +} + +impl Display for DatagramStream { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "<{}>", self.display) + } +} diff --git a/src/guestfs/mod.rs b/src/guestfs/mod.rs index ecb4c10..452545b 100644 --- a/src/guestfs/mod.rs +++ b/src/guestfs/mod.rs @@ -72,14 +72,6 @@ struct guestfs_stat { ctime: i64, } -impl Drop for guestfs_stat { - fn drop(&mut self) { - unsafe { - guestfs_free_stat(self); - } - } -} - #[link(name = "guestfs")] unsafe extern "C" { fn guestfs_create() -> *const guestfs_h; @@ -223,9 +215,13 @@ impl GuestFS { } } - pub(super) fn add_qemu_option(&self, key: &str, value: &str) -> Result<(), GuestFSError> { - let c_str_key = CString::new(key).expect("CString::new failed"); - let c_str_value = CString::new(value).expect("CString::new failed"); + pub(super) fn add_qemu_option>( + &self, + key: S, + value: S, + ) -> Result<(), GuestFSError> { + let c_str_key = CString::new(key.as_ref()).expect("CString::new failed"); + let c_str_value = CString::new(value.as_ref()).expect("CString::new failed"); if unsafe { guestfs_config(self.handle, c_str_key.as_ptr(), c_str_value.as_ptr()) } == 0 { Ok(()) } else { @@ -270,9 +266,13 @@ impl GuestFS { Ok(partitions_list) } - pub(super) fn mount_ro(&self, device: &str, mountpoint: &str) -> Result<(), GuestFSError> { - let c_str_device = CString::new(device).expect("CString::new failed"); - let c_str_mountpoint = CString::new(mountpoint).expect("CString::new failed"); + pub(super) fn mount_ro>( + &self, + device: S, + mountpoint: S, + ) -> Result<(), GuestFSError> { + let c_str_device = CString::new(device.as_ref()).expect("CString::new failed"); + let c_str_mountpoint = CString::new(mountpoint.as_ref()).expect("CString::new failed"); if unsafe { guestfs_mount_ro( self.handle, @@ -287,8 +287,8 @@ impl GuestFS { } } - pub(super) fn get_size(&self, path: &str) -> Result { - let c_str_path = CString::new(path).expect("CString::new failed"); + pub(super) fn get_size>(&self, path: S) -> Result { + let c_str_path = CString::new(path.as_ref()).expect("CString::new failed"); let size = unsafe { let result = guestfs_stat(self.handle, c_str_path.as_ptr()); if result.is_null() { @@ -301,8 +301,8 @@ impl GuestFS { Ok(size as usize) } - pub(super) fn set_append(&self, string: &str) -> Result<(), GuestFSError> { - let c_str = CString::new(string).expect("CString::new failed"); + pub(super) fn set_append>(&self, string: S) -> Result<(), GuestFSError> { + let c_str = CString::new(string.as_ref()).expect("CString::new failed"); let result = unsafe { guestfs_set_append(self.handle, c_str.as_ptr()) }; if result == 0 { Ok(()) @@ -311,8 +311,12 @@ impl GuestFS { } } - pub(super) fn read_chunk(&self, path: &str, offset: usize) -> Result, GuestFSError> { - let c_str_path = CString::new(path).expect("CString::new failed"); + pub(super) fn read_chunk>( + &self, + path: S, + offset: usize, + ) -> Result, GuestFSError> { + let c_str_path = CString::new(path.as_ref()).expect("CString::new failed"); unsafe { let mut size_r: libc::size_t = 0; let read_buffer = guestfs_pread( diff --git a/src/main.rs b/src/main.rs index e66a61b..608d0e6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,9 @@ +#[cfg(windows)] +compile_error!( + "This project does not support building on Windows due to its reliance on libguestfs and inotify." +); mod cursor; +mod datagram_stream; mod fs; mod fs_watch; mod guestfs; diff --git a/src/messages/mod.rs b/src/messages/mod.rs index 22d6834..595f815 100644 --- a/src/messages/mod.rs +++ b/src/messages/mod.rs @@ -127,9 +127,9 @@ impl OptionsAcknowledge { } } - pub(super) fn serialize(&self, buffer: &mut [u8]) -> Result<(usize, u16), BufferError> { + pub(super) fn serialize(&self, buffer: &mut [u8]) -> Result { if buffer.is_empty() { - return Ok((0, 0)); + return Ok(0); } let mut datagram = WriteCursor::new(buffer); datagram.put_ushort(OACK)?; @@ -141,7 +141,7 @@ impl OptionsAcknowledge { } offset }; - Ok((offset, 0)) + Ok(offset) } pub fn push(&mut self, option: (String, String)) { self.options.push(option) diff --git a/src/options/mod.rs b/src/options/mod.rs index e994076..51d3910 100644 --- a/src/options/mod.rs +++ b/src/options/mod.rs @@ -14,9 +14,13 @@ static TIMEOUT: &str = "timeout"; static BLKSIZE: &str = "blksize"; +const WINDOW_SIZE: &str = "windowsize"; + const BLOCK_SIZE_LIMIT: usize = u16::MAX as usize; -const ACK_TIMEOUT_LIMIT: usize = 60; +const ACK_TIMEOUT_LIMIT: usize = 255; + +const WINDOW_SIZE_LIMIT: usize = u16::MAX as usize; #[derive(Clone)] pub(super) struct Blksize { @@ -28,13 +32,10 @@ impl Blksize { if let Some(block_size_string) = options.get(BLKSIZE) && let Ok(block_size) = block_size_string.parse::() { - if block_size < BLOCK_SIZE_LIMIT { + if (8..=BLOCK_SIZE_LIMIT).contains(&block_size) { return Some(Self { block_size }); } else { - eprintln!( - "Requested block size {block_size} exceeds \ - maximum allowed block size {BLOCK_SIZE_LIMIT}" - ); + eprintln!("Requested {block_size} doesn't fit in range 8 .. ={BLOCK_SIZE_LIMIT}"); } } None @@ -44,16 +45,8 @@ impl Blksize { (String::from(BLKSIZE), self.block_size.to_string()) } - pub(super) fn is_last(&self, chunk_size: usize) -> bool { - chunk_size < self.block_size - } - - pub(super) fn read_chunk( - &self, - opened_file: &mut dyn OpenedFile, - buffer: &mut [u8], - ) -> Result { - opened_file.read_to(&mut buffer[..self.block_size]) + pub(super) fn get_size(&self) -> usize { + self.block_size } } @@ -86,11 +79,11 @@ impl AckTimeout { if let Some(timeout_string) = options.get(TIMEOUT) && let Ok(timeout) = timeout_string.parse::() { - if timeout <= ACK_TIMEOUT_LIMIT { + if (1..=ACK_TIMEOUT_LIMIT).contains(&timeout) { return Some(Self { timeout }); } else { eprintln!( - "Requested timeout {timeout} exceeds maximum allowed {ACK_TIMEOUT_LIMIT}" + "Requested timeout {timeout} doesn't fit in range 1 .. ={ACK_TIMEOUT_LIMIT}" ); } } @@ -126,3 +119,41 @@ impl TSize { (String::from(TSIZE), self.file_size.to_string()) } } + +pub(super) struct WindowSize(usize); + +impl Display for WindowSize { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[blocks_count: {}]", self.0) + } +} + +impl WindowSize { + pub(super) fn find_in(options: &HashMap) -> Option { + if let Some(window_size) = options.get(WINDOW_SIZE) + && let Ok(window_size) = window_size.parse::() + { + if (1..=WINDOW_SIZE_LIMIT).contains(&window_size) { + return Some(Self(window_size)); + } else { + eprintln!( + "Requested window size {window_size} doesn't fit in range 1 .. ={WINDOW_SIZE_LIMIT}" + ); + } + } + None + } + + pub(super) fn get_size(&self) -> usize { + self.0 + } + pub(super) fn as_key_pair(&self) -> (String, String) { + (String::from(WINDOW_SIZE), self.0.to_string()) + } +} + +impl Default for WindowSize { + fn default() -> Self { + Self(1) + } +} diff --git a/src/options/tests.rs b/src/options/tests.rs index f5c27d1..ee848b2 100644 --- a/src/options/tests.rs +++ b/src/options/tests.rs @@ -3,7 +3,7 @@ use super::*; #[test] fn find_block_size() { let mut options = HashMap::new(); - options.insert("blksize".to_string(), "1468".to_string()); + options.insert(BLKSIZE.to_string(), "1468".to_string()); let blk_size = Blksize::find_in(&options).unwrap(); assert_eq!(blk_size.block_size, 1468); assert_eq!( @@ -15,7 +15,7 @@ fn find_block_size() { #[test] fn find_tsize() { let mut options = HashMap::new(); - options.insert("tsize".to_string(), "0".to_string()); + options.insert(TSIZE.to_string(), "0".to_string()); assert!(TSize::is_requested(&options)); } @@ -23,7 +23,7 @@ fn find_tsize() { fn find_timeout() { let mut options = HashMap::new(); let timeout_value: usize = 10; - options.insert("timeout".to_string(), timeout_value.to_string()); + options.insert(TIMEOUT.to_string(), timeout_value.to_string()); let timeout = AckTimeout::find_in(&options).unwrap(); assert_eq!(timeout.timeout, timeout_value); } @@ -31,15 +31,55 @@ fn find_timeout() { #[test] fn test_timeout_cap() { let mut options = HashMap::new(); - options.insert("timeout".to_string(), (ACK_TIMEOUT_LIMIT + 1).to_string()); + options.insert(TIMEOUT.to_string(), (ACK_TIMEOUT_LIMIT + 1).to_string()); let find_result = AckTimeout::find_in(&options); assert!(find_result.is_none()); } +#[test] +fn test_timeout_bottom() { + let mut options = HashMap::new(); + options.insert(TIMEOUT.to_string(), 0.to_string()); + let find_result = AckTimeout::find_in(&options); + assert!(find_result.is_none()); +} + +#[test] +fn test_block_size_bottom() { + let mut options = HashMap::new(); + options.insert(BLKSIZE.to_string(), 7.to_string()); + let find_result = Blksize::find_in(&options); + assert!(find_result.is_none()); +} + #[test] fn test_block_size_cap() { let mut options = HashMap::new(); - options.insert("blksize".to_string(), (BLOCK_SIZE_LIMIT + 1).to_string()); + options.insert(BLKSIZE.to_string(), (BLOCK_SIZE_LIMIT + 1).to_string()); let find_result = Blksize::find_in(&options); assert!(find_result.is_none()); } + +#[test] +fn test_window_size() { + let mut options = HashMap::new(); + options.insert(WINDOW_SIZE.to_string(), 10.to_string()); + let find_result = WindowSize::find_in(&options); + assert!(find_result.is_some()); +} + +#[test] +fn test_window_bottom() { + let mut options = HashMap::new(); + options.insert(WINDOW_SIZE.to_string(), 0.to_string()); + let find_result = WindowSize::find_in(&options); + assert!(find_result.is_none()); +} + +#[test] +fn test_window_cap() { + let mut options = HashMap::new(); + options.insert(WINDOW_SIZE.to_string(), (WINDOW_SIZE_LIMIT + 1).to_string()); + let find_result = WindowSize::find_in(&options); + assert!(find_result.is_none()); +} diff --git a/src/peer_handler.rs b/src/peer_handler/mod.rs similarity index 51% rename from src/peer_handler.rs rename to src/peer_handler/mod.rs index 250d0bc..f3de8d9 100644 --- a/src/peer_handler.rs +++ b/src/peer_handler/mod.rs @@ -1,14 +1,15 @@ -use crate::cursor::{ReadCursor, WriteCursor}; +use crate::cursor::ReadCursor; +use crate::datagram_stream::DatagramStream; use crate::fs::{FileError, OpenedFile, Root}; use crate::local_fs::LocalRoot; use crate::messages::{OptionsAcknowledge, ReadRequest, TFTPError, UNDEFINED_ERROR}; use crate::nbd_disk::NBDConfig; -use crate::options::{AckTimeout, Blksize, TSize}; +use crate::options::{AckTimeout, Blksize, TSize, WindowSize}; use crate::remote_fs::{Config, VirtualRootError}; use serde_json::Value; use std::collections::HashMap; use std::fmt::{Debug, Display, Formatter}; -use std::io::ErrorKind; +use std::io; use std::net::{IpAddr, SocketAddr}; use std::ops::DerefMut; use std::path::{Path, PathBuf}; @@ -22,6 +23,9 @@ use tokio::sync::mpsc::{Receiver, Sender}; use tokio::task::{JoinHandle, LocalSet}; use tokio::time::timeout; +#[cfg(test)] +mod tests; + const ACK: u16 = 0x04; const DATA: u16 = 0x03; @@ -33,169 +37,178 @@ const ACCESS_VIOLATION: u16 = 0x02; const MAX_SESSIONS_PER_IP: usize = 128; -struct DatagramStream { - local_socket: UdpSocket, - peer_address: SocketAddr, - display: String, -} +const SEND_ATTEMPTS: u16 = 5; -impl DatagramStream { - fn new(local_socket: UdpSocket, peer_address: SocketAddr) -> Self { - let local_address = local_socket.local_addr().unwrap(); - let local_ip = local_address.ip().to_string(); - let local_port = local_address.port().to_string(); - let remote_ip = peer_address.ip().to_string(); - let remote_port = peer_address.port().to_string(); - let display = format!("{local_ip}:{local_port} <=> {remote_ip}:{remote_port}"); - Self { - local_socket, - peer_address, - display, - } - } - - pub(super) fn remote_port(&self) -> u16 { - self.peer_address.port() - } - - pub async fn send(&self, buffer: &[u8]) -> std::io::Result<()> { - match self.local_socket.send_to(buffer, self.peer_address).await { - Ok(sent) => { - if sent != buffer.len() { - Err(ErrorKind::ConnectionReset.into()) - } else { - Ok(()) - } +async fn fire_error(error: TFTPError, datagram_stream: &DatagramStream, buffer: &mut [u8]) { + match error.serialize(buffer) { + Ok(to_send) => { + if let Err(send_error) = datagram_stream.send(&buffer[..to_send]).await { + eprintln!("{datagram_stream}: Error sending {error}: {send_error}"); + } else { + eprintln!("{datagram_stream}: Sent {error}"); } - Err(error) => Err(error), } - } - - pub async fn recv(&self, buffer: &mut [u8], min_size: usize) -> std::io::Result { - loop { - match self.local_socket.recv_from(buffer).await { - Ok((recv_size, remote_address)) => { - if remote_address != self.peer_address { - eprintln!( - "{self}: Ignore datagram {recv_size} long from alien {remote_address}" - ); - } else if recv_size < min_size { - eprintln!("{self}: Ignore runt datagram {recv_size} long"); - } else { - return Ok(recv_size); - } - } - Err(error) => return Err(error), - } + Err(buffer_error) => { + eprintln!("{datagram_stream}: Error serializing {error}: {buffer_error}") } } } -impl Debug for DatagramStream { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "<{}>", self.display) - } -} - -impl Display for DatagramStream { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "<{}>", self.display) - } +struct Window { + block_size: u16, + buffers: Vec>, } -pub struct TFTPStream { - udp_stream: DatagramStream, - ack_timeout: AckTimeout, - send_attempts: usize, - recv_buffer: Vec, -} - -impl TFTPStream { - fn new(udp_stream: DatagramStream, ack_timeout: AckTimeout, send_attempts: usize) -> Self { +impl Window { + fn new(block_size: u16, window_size: u16) -> Self { Self { - udp_stream, - ack_timeout, - send_attempts, - recv_buffer: vec![0u8; u16::MAX as usize], + block_size, + buffers: (0..window_size) + .map(|_| vec![0; block_size as usize + 2 * size_of::()]) + .collect(), } } - async fn send_data(&mut self, block: &[u8], block_num: u16) -> Result<(), SendError> { - for attempt in 0..self.send_attempts { - if let Err(err) = self.udp_stream.send(block).await { - return Err(SendError::Network(err.to_string())); - } - match self.read_acknowledge().await { - Ok(block_ack) => { - if block_ack == block_num { - return Ok(()); - } - eprintln!("{self}: Expected acknowledge {block_num}, received {block_ack}"); - } - Err(SendError::Timeout) => { - eprintln!("{self}: Timeout waiting for {block_num}, attempt {attempt}"); - } - Err(send_error) => return Err(send_error), - } - } - Err(SendError::Timeout) + fn size(&self) -> u16 { + self.buffers.capacity() as u16 } - async fn fire_error(&mut self, buffer: &[u8]) { - _ = self.udp_stream.send(buffer).await; + fn push_block( + &mut self, + opened_file: &mut dyn OpenedFile, + index: u16, + ) -> Result<(usize, bool), FileError> { + let buffer = self.buffer(index); + buffer[0] = 0; + buffer[1] = DATA as u8; + buffer[2] = (index >> 8) as u8; + buffer[3] = index as u8; + let read_bytes = opened_file.read_to(&mut buffer[4..])?; + buffer.truncate(read_bytes + 4); + Ok((read_bytes, read_bytes < self.block_size as usize)) + } + fn buffer(&mut self, index: u16) -> &mut Vec { + let window_length = self.buffers.len(); + let buffer = &mut self.buffers[index as usize % window_length]; + unsafe { buffer.set_len(buffer.capacity()) } + buffer } - async fn read_acknowledge(&mut self) -> Result { - let recv_future = self.udp_stream.recv(&mut self.recv_buffer, 4); - if let Ok(read_result) = self.ack_timeout.timeout(recv_future).await { - let _read_size = match read_result { - Ok(size) => size, - Err(err) => return Err(SendError::Network(err.to_string())), - }; - let mut datagram = ReadCursor::new(&self.recv_buffer); - match datagram.extract_ushort() { - Ok(opcode) if opcode == ACK => Ok(datagram - .extract_ushort() - .map_err(|_| SendError::ACKParseError)?), - Ok(opcode) if opcode == ERROR => { - let error_code = datagram - .extract_ushort() - .map_err(|_| SendError::ACKParseError)?; - let error_message = datagram - .extract_string() - .map_err(|_| SendError::ACKParseError)?; - Err(SendError::ClientError(error_code, error_message)) - } - Ok(opcode) => { - eprintln!("{self}: Received unknown opcode 0x{opcode:02x}"); - Err(SendError::ACKParseError) + async fn send(&mut self, index: u16, datagram_stream: &DatagramStream) -> std::io::Result<()> { + let window_length = self.buffers.len(); + let buffer = &mut self.buffers[index as usize % window_length]; + datagram_stream.send(buffer).await + } +} + +async fn send_file( + mut opened_file: Box, + datagram_stream: &DatagramStream, + mut window: Window, + ack_timeout: AckTimeout, + buffer: &mut [u8], +) -> Result<(usize, usize), TFTPError> { + let mut bytes_sent: usize = 0; + let mut blocks_sent: usize = 0; + let mut last_acknowledged_index: u16 = 0; + let mut last_read_index: u16 = 0; + let mut done = false; + while !done { + let unacknowledged_count = last_read_index.wrapping_sub(last_acknowledged_index); + debug_assert!(unacknowledged_count <= window.size()); + let mut to_send = unacknowledged_count; + while to_send < window.size() { + last_read_index = last_read_index.wrapping_add(1); + if let Ok((read_bytes, is_last)) = + window.push_block(opened_file.as_mut(), last_read_index) + { + to_send += 1; + bytes_sent += read_bytes; + if is_last { + done = true; + break; } - Err(_) => Err(SendError::ACKParseError), + } else { + return Err(TFTPError::new("Read file error occurred", UNDEFINED_ERROR)); } - } else { - Err(SendError::Timeout) } + debug_assert!(to_send <= window.size()); + last_acknowledged_index = match send_reliably( + &mut window, + &ack_timeout, + datagram_stream, + buffer, + last_acknowledged_index.wrapping_add(1), + to_send, + ) + .await + { + Ok(received_acknowledged) => received_acknowledged, + Err(SendError::Timeout) => { + return Err(TFTPError::new("Send timeout occurred", UNDEFINED_ERROR)); + } + Err(SendError::ClientError(code, string)) => { + eprintln!("{datagram_stream}: Early termination [{code}] {string}"); + blocks_sent += to_send as usize; + return Ok((bytes_sent, blocks_sent)); + } + Err(_) => { + return Err(TFTPError::new("Unknown error occurred", UNDEFINED_ERROR)); + } + }; } + Ok((bytes_sent, blocks_sent)) } -impl Display for TFTPStream { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "", self.udp_stream, self.ack_timeout) +async fn read_acknowledge( + datagram_stream: &DatagramStream, + buffer: &mut [u8], + ack_timeout: &AckTimeout, +) -> Result { + let recv_future = datagram_stream.recv(buffer, 4); + if let Ok(read_result) = ack_timeout.timeout(recv_future).await { + let _read_size = match read_result { + Ok(size) => size, + Err(err) => { + eprintln!("{datagram_stream}: Read error: {:?}", err); + return Err(RecvError::Network); + } + }; + let mut datagram = ReadCursor::new(buffer); + match datagram.extract_ushort() { + Ok(opcode) if opcode == ACK => { + Ok(datagram.extract_ushort().map_err(|_| RecvError::ACKError)?) + } + Ok(opcode) if opcode == ERROR => { + let error_code = datagram.extract_ushort().map_err(|_| RecvError::ACKError)?; + let error_message = datagram.extract_string().map_err(|_| RecvError::ACKError)?; + Err(RecvError::ClientError(error_code, error_message)) + } + Ok(opcode) => { + eprintln!("{datagram_stream}: Received unknown opcode 0x{opcode:02x}"); + Err(RecvError::ACKError) + } + Err(_) => Err(RecvError::ACKError), + } + } else { + Err(RecvError::Timeout) } } -impl Debug for TFTPStream { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "", self.udp_stream, self.ack_timeout) - } +#[derive(Debug)] +pub(super) enum SendError { + Network, + Timeout, + ClientError(u16, String), + ACKError, } #[derive(Debug)] -pub(super) enum SendError { - Network(String), +pub(super) enum RecvError { + Network, Timeout, ClientError(u16, String), - ACKParseError, + ACKError, } pub(super) struct PeerHandler { @@ -230,6 +243,7 @@ impl PeerHandler { let mut available_roots: Vec> = vec![Box::new(LocalRoot::new(tftp_root.join(peer.to_string())))]; available_roots.extend(get_available_remote_roots(&tftp_root, &peer.to_string())); + available_roots.push(Box::new(LocalRoot::new(tftp_root.join("default")))); eprintln!("{peer}: Available roots: {available_roots:?}"); let runtime = runtime::Builder::new_current_thread() .enable_time() @@ -349,22 +363,14 @@ async fn handle_request( if send_sessions.len() >= send_sessions.capacity() { let error_message = "Maximum sessions per IP exceeded"; let tftp_error = TFTPError::new(error_message, UNDEFINED_ERROR); - if let Ok(to_send) = tftp_error.serialize(&mut send_buffer) - && let Err(error) = datagram_stream.send(&send_buffer[..to_send]).await - { - eprintln!("{datagram_stream}: Error sending {tftp_error}: {error}"); - } + fire_error(tftp_error, &datagram_stream, &mut send_buffer).await; return Err(IrrecoverableError(error_message.to_owned())); }; let mut opened_file = match open_file(&read_request, available_roots) { Ok(file) => file, Err(tftp_error) => { eprintln!("{datagram_stream}: {read_request} denied: {tftp_error}"); - if let Ok(to_send) = tftp_error.serialize(&mut send_buffer) - && let Err(error) = datagram_stream.send(&send_buffer[..to_send]).await - { - eprintln!("{datagram_stream}: Error sending {tftp_error}: {error}"); - } + fire_error(tftp_error, &datagram_stream, &mut send_buffer).await; return Ok(()); } }; @@ -372,17 +378,33 @@ async fn handle_request( send_sessions.insert( datagram_stream.remote_port(), tokio::task::spawn_local(async { - if let Some((mut tftp_stream, block_size)) = negotiate_options( - datagram_stream, - opened_file.as_mut(), + if let Some((window, ack_timeout)) = negotiate_options( + &datagram_stream, + &mut opened_file, &mut send_buffer, read_request.options, ) .await { - send_file(opened_file, &mut tftp_stream, block_size, &mut send_buffer).await; + match send_file( + opened_file, + &datagram_stream, + window, + ack_timeout, + &mut send_buffer, + ) + .await + { + Ok((sent_bytes, sent_blocks)) => eprintln!( + "{datagram_stream}: Sent {sent_bytes} bytes, {sent_blocks} blocks" + ), + Err(tftp_error) => { + fire_error(tftp_error, &datagram_stream, &mut send_buffer).await + } + }; + drop(send_buffer); + drop(datagram_stream); } - drop(send_buffer); }), ); Ok(()) @@ -454,12 +476,105 @@ fn read_json(path: &Path) -> Option { None } +async fn send_reliably( + window: &mut Window, + ack_timeout: &AckTimeout, + datagram_stream: &DatagramStream, + buffer: &mut [u8], + window_index: u16, + count: u16, +) -> Result { + for attempt in 1..=SEND_ATTEMPTS { + for block_index in (0..count).map(|v| window_index.wrapping_add(v)) { + if let Err(send_error) = window.send(block_index, datagram_stream).await { + eprintln!( + "{datagram_stream}: Network error while sending block {block_index}: {send_error}" + ); + return Err(SendError::Network); + } + } + return match read_acknowledge(datagram_stream, buffer, ack_timeout).await { + Ok(received_ack) if received_ack >= window_index => Ok(received_ack), + Ok(unexpected_ack) => { + let tftp_error = TFTPError::new("Received ACK from the past", UNDEFINED_ERROR); + eprintln!( + "{datagram_stream}: Received ACK {unexpected_ack} while expected > {window_index}" + ); + fire_error(tftp_error, datagram_stream, buffer).await; + Err(SendError::ACKError) + } + Err(RecvError::Timeout) => { + let window_end_index = window_index.wrapping_add(count); + eprintln!( + "{datagram_stream}: Timeout waiting for {window_index} .. {window_end_index}, attempt {attempt}" + ); + continue; + } + Err(RecvError::ClientError(error_code, error_message)) => { + Err(SendError::ClientError(error_code, error_message)) + } + Err(_) => Err(SendError::Network), + }; + } + Err(SendError::Timeout) +} + +async fn send_oack_reliably( + oack: &OptionsAcknowledge, + datagram_stream: &DatagramStream, + ack_timeout: &AckTimeout, + buffer: &mut [u8], +) -> io::Result<()> { + let oack_index = 0; + let oack_size = match oack.serialize(buffer) { + Ok(size) => size, + Err(buffer_error) => { + let tftp_error = TFTPError::new("OACK build error", UNDEFINED_ERROR); + fire_error(tftp_error, datagram_stream, buffer).await; + return Err(io::Error::other(format!( + "Error building options: {buffer_error}" + ))); + } + }; + for attempt in 1..=SEND_ATTEMPTS { + datagram_stream.send(&buffer[..oack_size]).await?; + match read_acknowledge(datagram_stream, buffer, ack_timeout).await { + Ok(ack_num) if ack_num == oack_index => return Ok(()), + Ok(ack_num) => { + let tftp_error = TFTPError::new("Unexpected non-zero ACK", UNDEFINED_ERROR); + fire_error(tftp_error, datagram_stream, buffer).await; + return Err(io::Error::other(format!( + "Received unexpected ACK {ack_num} while expecting {oack_index}" + ))); + } + Err(RecvError::Timeout) => { + eprintln!("Timeout waiting for ACK {oack_index}, attempt {attempt}"); + continue; + } + Err(RecvError::ClientError(code, string)) => { + return Err(io::Error::other(format!( + "Early termination while options negotiation [{code}] {string}" + ))); + } + Err(error) => { + return Err(io::Error::other(format!("ACK read error: {:?}", error))); + } + } + } + let tftp_error = TFTPError::new("Send timeout occurred", UNDEFINED_ERROR); + fire_error(tftp_error, datagram_stream, buffer).await; + Err(io::Error::new( + io::ErrorKind::TimedOut, + format!("Timeout waiting for ACK {oack_index}"), + )) +} + async fn negotiate_options( - udp_stream: DatagramStream, - opened_file: &mut dyn OpenedFile, - send_buffer: &mut [u8], + datagram_stream: &DatagramStream, + opened_file: &mut Box, + buffer: &mut [u8], options: HashMap, -) -> Option<(TFTPStream, Blksize)> { +) -> Option<(Window, AckTimeout)> { let mut oack = OptionsAcknowledge::new(); let ack_timeout = { if let Some(timeout) = AckTimeout::find_in(&options) { @@ -477,47 +592,31 @@ async fn negotiate_options( Default::default() } }; - let mut tftp_stream = TFTPStream::new(udp_stream, ack_timeout, 5); if TSize::is_requested(&options) { - match TSize::obtain(opened_file) { + match TSize::obtain(opened_file.as_mut()) { Ok(tsize) => oack.push(tsize.as_key_pair()), Err(err) => { - eprintln!("{tftp_stream}: Can't obtain TSize due to {err:?}") + eprintln!("{datagram_stream}: Can't obtain TSize due to {err:?}") } } }; - if oack.has_options() { - eprintln!("{tftp_stream}: {oack}"); - match oack.serialize(send_buffer) { - Ok((oack_size, block_num)) => { - match tftp_stream - .send_data(&send_buffer[..oack_size], block_num) - .await - { - Ok(_) => {} - Err(SendError::ClientError(code, message)) => { - eprintln!("{tftp_stream}: Client responded with [{code}]: {message}"); - return None; - } - Err(error) => { - eprintln!("{tftp_stream}: Error sending options: {error:?}"); - return None; - } - } - } - Err(buffer_error) => { - eprintln!("{tftp_stream}: Error building options: {buffer_error}"); - let tftp_error = TFTPError::new("OACK build error", UNDEFINED_ERROR); - if let Ok(error_length) = tftp_error.serialize(send_buffer) { - tftp_stream.fire_error(&send_buffer[..error_length]).await; - } else { - eprintln!("{tftp_stream}: Error serializing {buffer_error}"); - } - return None; - } + let window_size = { + if let Some(window_size) = WindowSize::find_in(&options) { + oack.push(window_size.as_key_pair()); + window_size + } else { + Default::default() } }; - Some((tftp_stream, block_size)) + if oack.has_options() + && let Err(oack_negotiation_error) = + send_oack_reliably(&oack, datagram_stream, &ack_timeout, buffer).await + { + eprintln!("{datagram_stream}: {oack_negotiation_error}"); + return None; + }; + let window = Window::new(block_size.get_size() as u16, window_size.get_size() as u16); + Some((window, ack_timeout)) } fn open_file( @@ -538,76 +637,6 @@ fn open_file( Err(TFTPError::new("File not found", FILE_NOT_FOUND)) } -async fn send_file( - mut opened_file: Box, - tftp_stream: &mut TFTPStream, - block_size: Blksize, - send_buffer: &mut [u8], -) { - let mut offset: usize = 0; - let mut block_num: u16 = 1; - loop { - let header_size = place_block_header(send_buffer, block_num); - let chunk_size = - match block_size.read_chunk(opened_file.as_mut(), &mut send_buffer[header_size..]) { - Ok(chunk_size) => chunk_size, - Err(_err) => { - let tftp_error = TFTPError::new("Read error occurred", UNDEFINED_ERROR); - eprintln!("{tftp_stream}: {tftp_error}"); - if let Ok(error_length) = tftp_error.serialize(send_buffer) { - tftp_stream.fire_error(&send_buffer[..error_length]).await; - } else { - eprintln!("{tftp_stream}: Error serializing {tftp_error}"); - } - return; - } - }; - match tftp_stream - .send_data(&send_buffer[..header_size + chunk_size], block_num) - .await - { - Ok(_) => {} - Err(SendError::Network(string)) => { - eprintln!("{tftp_stream}: Network error while sending block {block_num}: {string}"); - return; - } - Err(SendError::Timeout) => { - let tftp_error = - TFTPError::new(format!("Timed out block {block_num}"), UNDEFINED_ERROR); - eprintln!("{tftp_stream}: {tftp_error}"); - if let Ok(error_length) = tftp_error.serialize(send_buffer) { - tftp_stream.fire_error(&send_buffer[..error_length]).await; - } else { - eprintln!("{tftp_stream}: Error serializing {tftp_error}"); - } - return; - } - Err(SendError::ClientError(code, message)) => { - eprintln!("{tftp_stream}: Client error received: [{code}] {message}"); - return; - } - Err(send_error) => { - eprintln!( - "{tftp_stream}: Unknown error while sending block {block_num}: {send_error:?}" - ); - return; - } - } - offset += chunk_size; - if block_size.is_last(chunk_size) { - eprintln!("{tftp_stream}: Sent {offset} bytes"); - return; - } - block_num = block_num.wrapping_add(1); - } -} - -fn place_block_header(buffer: &mut [u8], block_number: u16) -> usize { - let mut datagram = WriteCursor::new(buffer); - _ = datagram.put_ushort(DATA).unwrap(); - datagram.put_ushort(block_number).unwrap() -} - #[derive(Debug)] struct IrrecoverableError(String); diff --git a/src/peer_handler/tests.rs b/src/peer_handler/tests.rs new file mode 100644 index 0000000..8b4ff59 --- /dev/null +++ b/src/peer_handler/tests.rs @@ -0,0 +1,228 @@ +use crate::datagram_stream::DatagramStream; +use crate::fs::{FileError, OpenedFile}; +use crate::options::AckTimeout; +use crate::peer_handler::{ACK, DATA, Window, send_file}; +use std::time::Duration; +use std::{fmt, io}; +use tokio::join; +use tokio::net::UdpSocket; +use tokio::time::timeout; + +fn xorshift64star(index: usize, seed: usize) -> usize { + let mut x = index ^ seed; + x ^= x >> 12; + x ^= x << 25; + x ^= x >> 27; + (x.wrapping_mul(0x2545F4914F6CDD1D)) >> 56 +} + +fn weak_pseudo_random_data(len: usize, seed: usize) -> Vec { + (0..len).map(|i| xorshift64star(i, seed) as u8).collect() +} + +fn generate_data(size: usize) -> Vec { + weak_pseudo_random_data(size, size) +} + +struct VirtualOpenedFile { + buffer: Vec, + offset: usize, +} + +impl VirtualOpenedFile { + fn new(buffer: Vec) -> Self { + Self { buffer, offset: 0 } + } +} + +impl fmt::Display for VirtualOpenedFile { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "VirtualOpenedFile size {} [{}]", + self.buffer.len(), + self.offset + ) + } +} + +impl fmt::Debug for VirtualOpenedFile { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "VirtualOpenedFile size {} [{}]", + self.buffer.len(), + self.offset + ) + } +} + +impl OpenedFile for VirtualOpenedFile { + fn read_to(&mut self, buffer: &mut [u8]) -> Result { + let slice_length = buffer.len().min(self.buffer.len() - self.offset); + eprintln!("{}: {} {}", self, slice_length, buffer.len()); + buffer[..slice_length] + .copy_from_slice(&self.buffer[self.offset..self.offset + slice_length]); + self.offset += slice_length; + Ok(slice_length) + } + + fn get_size(&mut self) -> Result { + Ok(self.buffer.len()) + } +} + +async fn make_streams() -> (DatagramStream, DatagramStream) { + let server_socket = UdpSocket::bind("127.0.0.10:0").await.unwrap(); + let client_socket = UdpSocket::bind("127.0.0.20:0").await.unwrap(); + let server_address = server_socket.local_addr().unwrap(); + let client_address = client_socket.local_addr().unwrap(); + ( + DatagramStream::new(server_socket, client_address), + DatagramStream::new(client_socket, server_address), + ) +} + +#[derive(Debug)] +pub(crate) struct DownloadError(String); + +impl fmt::Display for DownloadError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0.clone()) + } +} + +impl From for DownloadError { + fn from(value: io::Error) -> Self { + DownloadError(value.to_string()) + } +} + +async fn download_stream( + datagram_stream: &DatagramStream, + block_size: u16, + window_size: u16, +) -> Result, DownloadError> { + let mut read_data: Vec = Vec::new(); + let block_header_size = 4; + let expected_message_size = block_size as usize + block_header_size; + let mut buffer = vec![0u8; expected_message_size]; + let mut last_received_block_index: u16 = 0; + let mut done = false; + while !done { + for _ in 0..window_size { + let recv_fut = datagram_stream.recv(&mut buffer, block_header_size); + let received_bytes = match timeout(Duration::from_secs(5), recv_fut).await { + Ok(result) => result?, + Err(_timeout) => return Err(DownloadError("timeout".to_string())), + }; + let opcode = ((buffer[0] as u16) << 8) | buffer[1] as u16; + if opcode != DATA { + return Err(DownloadError("Wrong opcode received: {opcode}".to_string())); + } + last_received_block_index = ((buffer[2] as u16) << 8) | (buffer[3] as u16); + read_data.extend_from_slice(&buffer[block_header_size..received_bytes]); + if received_bytes < expected_message_size { + eprintln!( + "Received {received_bytes}, expected {expected_message_size} bytes. Break" + ); + done = true; + break; + } + } + + buffer[0] = 0; + buffer[1] = ACK as u8; + buffer[2] = (last_received_block_index >> 8) as u8; + buffer[3] = (last_received_block_index & 0xFF) as u8; + datagram_stream.send(&buffer[..block_header_size]).await?; + eprintln!("Sent ACK for {}", last_received_block_index); + } + eprintln!("Done"); + Ok(read_data) +} + +#[tokio::test(flavor = "current_thread")] +async fn send_aligned_data() { + let test_data = generate_data(100); + let opened_file = VirtualOpenedFile::new(test_data.clone()); + let (server_stream, client_stream) = make_streams().await; + let ack_timeout = AckTimeout::default(); + let block_size = 100; + let window_size = 1; + let window = Window::new(block_size, window_size); + let mut buffer = vec![0; 1024]; + let send_coro = send_file( + Box::new(opened_file), + &server_stream, + window, + ack_timeout, + &mut buffer, + ); + let recv_coro = download_stream(&client_stream, block_size, window_size); + let (_send_result, recv_result) = join!(send_coro, recv_coro); + assert_eq!(recv_result.unwrap(), test_data); +} +#[tokio::test(flavor = "current_thread")] +async fn send_unaligned_data() { + let test_data = generate_data(512); + let opened_file = VirtualOpenedFile::new(test_data.clone()); + let (server_stream, client_stream) = make_streams().await; + let ack_timeout = AckTimeout::default(); + let block_size = 100; + let window_size = 1; + let window = Window::new(block_size, window_size); + let mut buffer = vec![0; 1024]; + let send_coro = send_file( + Box::new(opened_file), + &server_stream, + window, + ack_timeout, + &mut buffer, + ); + let recv_coro = download_stream(&client_stream, block_size, window_size); + let (_send_result, recv_result) = join!(send_coro, recv_coro); + assert_eq!(recv_result.unwrap(), test_data); +} +#[tokio::test(flavor = "current_thread")] +async fn send_aligned_data_windowed() { + let test_data = generate_data(100); + let opened_file = VirtualOpenedFile::new(test_data.clone()); + let (server_stream, client_stream) = make_streams().await; + let ack_timeout = AckTimeout::default(); + let block_size = 100; + let window_size = 5; + let window = Window::new(block_size, window_size); + let mut buffer = vec![0; 1024]; + let send_coro = send_file( + Box::new(opened_file), + &server_stream, + window, + ack_timeout, + &mut buffer, + ); + let recv_coro = download_stream(&client_stream, block_size, window_size); + let (_send_result, recv_result) = join!(send_coro, recv_coro); + assert_eq!(recv_result.unwrap(), test_data); +} +#[tokio::test(flavor = "current_thread")] +async fn send_unaligned_data_windowed() { + let test_data = generate_data(512); + let opened_file = VirtualOpenedFile::new(test_data.clone()); + let (server_stream, client_stream) = make_streams().await; + let ack_timeout = AckTimeout::default(); + let block_size = 100; + let window_size = 5; + let window = Window::new(block_size, window_size); + let mut buffer = vec![0; 1024]; + let send_coro = send_file( + Box::new(opened_file), + &server_stream, + window, + ack_timeout, + &mut buffer, + ); + let recv_coro = download_stream(&client_stream, block_size, window_size); + let (_send_result, recv_result) = join!(send_coro, recv_coro); + assert_eq!(recv_result.unwrap(), test_data); +} diff --git a/tests/common/client.rs b/tests/common/client.rs index 3d9208c..9178618 100644 --- a/tests/common/client.rs +++ b/tests/common/client.rs @@ -14,6 +14,8 @@ const _ACK: u16 = 0x04; const _ERR: u16 = 0x05; const _OACK: u16 = 0x06; +const _WINDOW_SIZE: &str = "windowsize"; + #[derive(Debug)] struct _SendError { message: String, @@ -163,7 +165,7 @@ impl DatagramStream { } } - async fn send(&self, buffer: &[u8]) -> io::Result<()> { + pub(crate) async fn send(&self, buffer: &[u8]) -> io::Result<()> { match self.local_socket.send_to(buffer, self.peer_address).await { Ok(sent) => { if sent != buffer.len() { @@ -386,7 +388,7 @@ impl OACK { } pub(crate) struct Block { - datagram_stream: DatagramStream, + pub(crate) datagram_stream: DatagramStream, read_buffer: [u8; _BUFFER_SIZE], write_buffer: [u8; _BUFFER_SIZE], read_bytes: usize, @@ -439,6 +441,41 @@ impl Block { write_bytes: buffer_size, }) } + pub(crate) async fn read_next( + mut self, + read_timeout: usize, + ) -> Result> { + let duration = time::Duration::from_secs(read_timeout as u64); + let read_future = self + .datagram_stream + .recv(&mut self.read_buffer, read_timeout, 4); + match tokio::time::timeout(duration, read_future).await { + Ok(Ok(read_bytes)) => { + let mut read_cursor = ReadCursor::new(&mut self.read_buffer[..read_bytes]); + match read_cursor.extract_ushort() { + Ok(code) if code == _DATA => Ok(Block { + datagram_stream: self.datagram_stream, + read_buffer: self.read_buffer, + write_buffer: self.write_buffer, + read_bytes, + }), + Ok(code) if code == _ERR => { + let error_code = read_cursor.extract_ushort().unwrap(); + let message = read_cursor.extract_string().unwrap(); + Err(TFTPClientError::ClientError(error_code, message)) + } + Ok(_code) => Err(TFTPClientError::UnexpectedData( + self.read_buffer[..read_bytes].to_vec(), + )), + Err(parse_error) => { + Err(TFTPClientError::ParseError(format!("{parse_error:?}"))) + } + } + } + Ok(Err(err)) => Err(TFTPClientError::IO(err)), + Err(_timeout_error) => Err(TFTPClientError::Timeout(self)), + } + } } pub(crate) struct SentACK { @@ -563,6 +600,65 @@ pub(crate) async fn download(client: TFTPClient, file: &str) -> Result, Ok(read_data) } +pub(crate) async fn download_window( + client: TFTPClient, + file: &str, + window_size: u16, +) -> Result, DownloadError> { + let default_timeout: usize = 5; + let default_block_size: usize = 512; + let mut read_data: Vec = Vec::new(); + let options = HashMap::from([(_WINDOW_SIZE.to_string(), window_size.to_string())]); + let sent_request = client + .send_optioned_read_request(file, &options) + .await + .map_err(|error| DownloadError::from(error))?; + let oack = sent_request + .read_oack(default_timeout) + .await + .map_err(|error| DownloadError::from(error))?; + if let Some(window_size_received) = oack.fields().get(_WINDOW_SIZE) { + if let Ok(window_size_received) = window_size_received.parse::() { + if window_size_received != window_size { + return Err(DownloadError(format!( + "Window size mismatch: Received {}, expected {}", + window_size_received, window_size + ))); + } + } else { + return Err(DownloadError(format!( + "Window size not recognized: {}", + window_size_received + ))); + } + } else { + return Err(DownloadError("Window size not set".into())); + } + let mut sent_ack: Option = Some(oack.acknowledge().await?); + let mut last_block: Option = None; + let mut done = false; + while !done { + for _ in 0..(window_size) { + let received_block = { + if let Some(last_block) = last_block.take() { + last_block.read_next(default_timeout).await? + } else { + sent_ack.take().unwrap().read_next(default_timeout).await? + } + }; + let recv_block_len = received_block.data().len(); + read_data.extend(received_block.data()); + last_block = Some(received_block); + done = recv_block_len < default_block_size; + if done { + break; + } + } + sent_ack = Some(last_block.take().unwrap().acknowledge().await?); + } + Ok(read_data) +} + #[derive(Debug)] pub(crate) struct DownloadError(String); diff --git a/tests/test_server.rs b/tests/test_server.rs index 64327cd..28b345d 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -9,7 +9,7 @@ use std::path::PathBuf; use std::{fs, time}; use tokio::net::UdpSocket; -use crate::common::client::{TFTPClientError, download}; +use crate::common::client::{TFTPClientError, download, download_window}; mod common; @@ -135,6 +135,37 @@ async fn attempt_download_nonexisting_file() { ); } +#[tokio::test(flavor = "current_thread")] +async fn attempt_download_file_default() { + let arbitrary_source_ip = "127.0.0.11"; + let server_dir = mk_tmp(attempt_download_file_default); + let data = make_payload(512); + let file_name = "file.txt"; + let file = server_dir.join("default").join(file_name); + _write_file(&file, &data); + let running_server = start_rtftp(server_dir).await; + let client = running_server.open_paired_client(arbitrary_source_ip).await; + let read_data = download(client, &file_name).await.unwrap(); + assert_eq!(read_data, data); +} + +#[tokio::test(flavor = "current_thread")] +async fn attempt_download_file_peer_takes_precendence() { + let arbitrary_source_ip = "127.0.0.11"; + let server_dir = mk_tmp(attempt_download_file_peer_takes_precendence); + let file_name = "file.txt"; + let default_data = make_payload(512); + let default_file = server_dir.join("default").join(file_name); + _write_file(&default_file, &default_data); + let peer_data = make_payload(768); + let peer_file = server_dir.join(arbitrary_source_ip).join(file_name); + _write_file(&peer_file, &peer_data); + let running_server = start_rtftp(server_dir).await; + let client = running_server.open_paired_client(arbitrary_source_ip).await; + let read_data = download(client, &file_name).await.unwrap(); + assert_eq!(read_data, peer_data); +} + #[tokio::test(flavor = "current_thread")] async fn access_violation() { let server_dir = mk_tmp(access_violation); @@ -263,6 +294,7 @@ async fn change_timeout() { (Vec::new(), 0u64), (Vec::new(), 0u64), (Vec::new(), 0u64), + (Vec::new(), 0u64), ]; let local_read_timeout = 2usize; let mut buffer = [0u8; _BUFFER_SIZE]; @@ -321,15 +353,25 @@ async fn change_timeout() { retry_buffers[3].1, 4 ); assert_eq!( - retry_buffers[4].0, b"", + retry_buffers[4].0, b"\x00\x05\x00\x00Send timeout occurred\x00", "5: Received: {:?}, Expected: {:?}", - retry_buffers[4].0, b"" + retry_buffers[4].0, b"\x00\x05\x00\x00Send timeout occurred\x00" ); assert_eq!( - retry_buffers[4].1, 0, + retry_buffers[4].1, 5, "5: Timestamp mismatch. Received: {}, Expected: {}", retry_buffers[4].1, 5 ); + assert_eq!( + retry_buffers[5].0, b"", + "5: Received: {:?}, Expected: {:?}", + retry_buffers[5].0, b"" + ); + assert_eq!( + retry_buffers[5].1, 0, + "5: Timestamp mismatch. Received: {}, Expected: {}", + retry_buffers[5].1, 0 + ); } #[tokio::test(flavor = "current_thread")] @@ -570,3 +612,83 @@ async fn test_download_nbd_file_nonaligned_augmented() { let data = make_payload(4194319); assert_eq!(read_data, data); } + +#[tokio::test(flavor = "current_thread")] +async fn download_local_aligned_file_window() { + let source_ip = "127.0.0.11"; + let server_dir = mk_tmp(download_local_aligned_file_window); + let payload_size = 4096; + let data = make_payload(payload_size); + let file_name = "file.txt"; + let file = server_dir.join(source_ip).join(file_name); + _write_file(&file, &data); + let running_server = start_rtftp(server_dir).await; + let client = running_server.open_paired_client(source_ip).await; + let read_data = download_window(client, file_name, 5).await.unwrap(); + assert_eq!(read_data, data); +} + +#[tokio::test(flavor = "current_thread")] +async fn download_local_unaligned_file_window() { + let source_ip = "127.0.0.11"; + let server_dir = mk_tmp(download_local_unaligned_file_window); + let payload_size = 4096; + let data = make_payload(payload_size); + let file_name = "file.txt"; + let file = server_dir.join(source_ip).join(file_name); + _write_file(&file, &data); + let running_server = start_rtftp(server_dir).await; + let client = running_server.open_paired_client(source_ip).await; + let read_data = download_window(client, file_name, 5).await.unwrap(); + assert_eq!(read_data, data); +} + +#[tokio::test(flavor = "current_thread")] +async fn file_window_partial_ack() { + let source_ip = "127.0.0.11"; + let server_dir = mk_tmp(file_window_partial_ack); + let payload_size = 4096; + let data = make_payload(payload_size); + let file_name = "file.txt"; + let file = server_dir.join(source_ip).join(file_name); + _write_file(&file, &data); + let running_server = start_rtftp(server_dir).await; + let client = running_server.open_paired_client(source_ip).await; + let block_size = 100; + let send_options = HashMap::from([ + ("windowsize".to_string(), 3.to_string()), + ("timeout".to_string(), 1.to_string()), + ("blksize".to_string(), block_size.to_string()), + ]); + let sent_request = client + .send_optioned_read_request(file_name, &send_options) + .await + .unwrap(); + let oack = sent_request.read_oack(5).await.unwrap(); + let sent_ack = oack.acknowledge().await.unwrap(); + let first_block = sent_ack.read_next(2).await.unwrap(); + assert_eq!(first_block.data(), data[..block_size].to_vec()); + let second_block = first_block.read_next(2).await.unwrap(); + assert_eq!( + second_block.data(), + data[block_size..block_size * 2].to_vec() + ); + let third_block = second_block.read_next(2).await.unwrap(); + assert_eq!( + third_block.data(), + data[block_size * 2..block_size * 3].to_vec() + ); + let first_block_acknowledge = b"\x00\x04\x00\x01"; + let datagram_stream = third_block.datagram_stream; + datagram_stream.send(first_block_acknowledge).await.unwrap(); + let mut buffer = [0u8; _BUFFER_SIZE]; + datagram_stream.recv(&mut buffer, 2, 0).await.unwrap(); + let second_block_num = u16::from_be_bytes(buffer[2..4].try_into().unwrap()); + assert_eq!(second_block_num, 2); + datagram_stream.recv(&mut buffer, 2, 0).await.unwrap(); + let third_block_num = u16::from_be_bytes(buffer[2..4].try_into().unwrap()); + assert_eq!(third_block_num, 3); + datagram_stream.recv(&mut buffer, 2, 0).await.unwrap(); + let forth_block_num = u16::from_be_bytes(buffer[2..4].try_into().unwrap()); + assert_eq!(forth_block_num, 4); +}