diff --git a/aft/src/clients.rs b/aft/src/clients.rs index f628dd2..3b01a11 100644 --- a/aft/src/clients.rs +++ b/aft/src/clients.rs @@ -17,7 +17,8 @@ use aft_crypto::{ use log::{debug, error, info}; use sha2::{Digest, Sha256}; use std::{ - io::{self, Read, Write}, + + io::{self, copy, BufReader, Read, Write}, net::{TcpListener, TcpStream}, }; @@ -66,20 +67,20 @@ where T: AeadInPlace, { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut read_buf = vec![0; buf.len() + AES_GCM_NONCE_SIZE + AES_GCM_TAG_SIZE]; - let bytes_read = self.0.read(&mut read_buf)?; + let mut read_buf = Vec::with_capacity(buf.len() + AES_GCM_NONCE_SIZE + AES_GCM_TAG_SIZE); - if bytes_read < AES_GCM_NONCE_SIZE { - return Ok(0); - } + let bytes_read = + (&self.0).take((buf.len() + AES_GCM_NONCE_SIZE + AES_GCM_TAG_SIZE) as u64).read_to_end(&mut read_buf)?; - read_buf.truncate(bytes_read); + if bytes_read == 0 { + return Ok(0) + } let (data, nonce) = read_buf.split_at(read_buf.len() - AES_GCM_NONCE_SIZE); let dec_buf = self.1.decrypt(data, nonce).expect("Could not decrypt."); buf[..dec_buf.len()].copy_from_slice(&dec_buf); - Ok(bytes_read - AES_GCM_NONCE_SIZE - AES_GCM_TAG_SIZE) + Ok(buf.len()) } } @@ -206,22 +207,19 @@ where /// /// Returns the file-checksum of the sender's. fn read_write_data(&mut self, file: &mut FileOperations, supposed_len: u64) -> io::Result> { - let mut content = vec![0; MAX_CONTENT_LEN]; - info!("Reading file chunks ..."); - loop { - self.get_mut_writer().read_ext(&mut content)?; - if &content[..SIGNAL_LEN] == Signals::EndFt.as_bytes() { - file.file.set_len(supposed_len)?; - break; - } + let mut reader = BufReader::with_capacity(MAX_CONTENT_LEN, self.get_mut_writer()); + copy(&mut reader, &mut file.file)?; - file.update_checksum(&content); - file.write(&content)?; - } + file.seek_end(MAX_CONTENT_LEN as i64)?; + + let mut checksum = [0; MAX_CHECKSUM_LEN]; + file.read_seek_file(&mut checksum)?; + + file.file.set_len(supposed_len)?; // Returns the sender's checksum - Ok(content[SIGNAL_LEN..MAX_CHECKSUM_LEN + SIGNAL_LEN].to_vec()) + Ok(checksum.to_vec()) } /// Returns true if checksums are equal, false if they're not. @@ -249,7 +247,6 @@ where file.seek_start(0)?; debug!("Computing checksum ..."); - // Until EOF while file.get_index()? != end_pos && file.read_seek_file(&mut content)? != 0 { file.update_checksum(&content); } @@ -301,6 +298,10 @@ where let filename = metadata["metadata"]["filename"].as_str().unwrap_or("null"); let recv_checksum = self.read_write_data(&mut file, sizeb)?; + + info!("Computing checksum ..."); + file.compute_checksum()?; + // If the checksum isn't good if !self.check_checksum(&recv_checksum, &file.checksum()) && get_accept_input("Keep the file? ").expect("Couldn't read answer") != 'y' diff --git a/aft/src/constants.rs b/aft/src/constants.rs index f92c6b0..081300d 100644 --- a/aft/src/constants.rs +++ b/aft/src/constants.rs @@ -14,7 +14,7 @@ pub const MAX_MODIFIED_LEN: usize = 12; pub const MAX_IDENTIFIER_LEN: usize = 10; /// Maximum buffer length that is received from a stream. pub const MAX_METADATA_LEN: usize = MAX_FILENAME_LEN + MAX_TYPE_LEN + MAX_SIZE_LEN + MAX_MODIFIED_LEN + 40 /* 40 = other chars such as { */; -/// Maximum size of a chunk (64KB). +/// Maximum size of a chunk. pub const MAX_CONTENT_LEN: usize = 65536; /// Maximum checksum length (Sha256 length in bytes). pub const MAX_CHECKSUM_LEN: usize = 32; diff --git a/aft/src/main.rs b/aft/src/main.rs index 852da53..bfcd775 100644 --- a/aft/src/main.rs +++ b/aft/src/main.rs @@ -13,7 +13,6 @@ use aft_crypto::{ password_generator::generate_passphrase, }; use config::Config; -use env_logger; use log::{error, info, Level}; use sender::Sender; use std::{env::args as args_fn, io::Write, net::{Ipv4Addr, ToSocketAddrs}}; diff --git a/aft/src/relay.rs b/aft/src/relay.rs index d31b832..ec44890 100644 --- a/aft/src/relay.rs +++ b/aft/src/relay.rs @@ -1,13 +1,14 @@ //! Handling relay functionality. use crate::{ - constants::{CLIENT_RECV, MAX_IDENTIFIER_LEN, RELAY, SIGNAL_LEN}, + + constants::{CLIENT_RECV, MAX_IDENTIFIER_LEN, RELAY, SIGNAL_LEN, MAX_METADATA_LEN}, utils::{bytes_to_string, Signals}, }; use log::{debug, error, info}; use sha2::{Digest, Sha256}; use std::{collections::HashMap, io, sync::Arc}; use tokio::{ - io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt}, + io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, sync::RwLock, }; @@ -32,6 +33,32 @@ macro_rules! error_conn { }; } +// Temporary until https://github.com/tokio-rs/tokio/issues/6454 is implemented, then we could use +// copy_bidirectional again. +async fn pre_send(sender: &mut TcpStream, receiver: &mut TcpStream) -> io::Result<()> { + // Write to the sender the receiver's public key + let mut receiver_pk = [0u8; aft_crypto::exchange::KEY_LENGTH]; + receiver.read_exact(&mut receiver_pk).await?; + sender.write_all(&receiver_pk).await?; + + let first_limit = + aft_crypto::exchange::KEY_LENGTH + + MAX_METADATA_LEN + aft_crypto::data::AES_GCM_NONCE_SIZE + aft_crypto::data::AES_GCM_TAG_SIZE; + + let mut pre_buf = sender.take(first_limit as u64); + // Write to the receiver the sender's public key and the metadata + tokio::io::copy(&mut pre_buf, receiver).await?; + + let second_limit = + 8 + crate::constants::SHA_256_LEN + + 2*(aft_crypto::data::AES_GCM_NONCE_SIZE + aft_crypto::data::AES_GCM_TAG_SIZE); + let mut pre_buf = receiver.take(second_limit as u64); + // Write to the receiver the sender's public key and the metadata + tokio::io::copy(&mut pre_buf, sender).await?; + + Ok(()) +} + async fn handle_sender(sender: &mut TcpStream, clients: MovT, recv_identifier: &str, sen_identifier: &str) -> io::Result { @@ -89,7 +116,11 @@ async fn handle_sender(sender: &mut TcpStream, clients: MovT, re } } - copy_bidirectional(sender, &mut receiver).await?; + // https://github.com/tokio-rs/tokio/issues/6454 + pre_send(sender, &mut receiver).await?; + + let mut sen_buf = tokio::io::BufReader::with_capacity(crate::constants::MAX_CONTENT_LEN, sender); + tokio::io::copy_buf(&mut sen_buf, &mut receiver).await?; Ok(true) } diff --git a/aft/src/sender.rs b/aft/src/sender.rs index 41377a3..ed17149 100644 --- a/aft/src/sender.rs +++ b/aft/src/sender.rs @@ -2,7 +2,7 @@ use crate::{ clients::{BaseSocket, Crypto, SWriter}, constants::{ - CLIENT_SEND, MAX_CHECKSUM_LEN, MAX_CONTENT_LEN, MAX_METADATA_LEN, RELAY, SIGNAL_LEN, + CLIENT_SEND, MAX_CHECKSUM_LEN, MAX_CONTENT_LEN, MAX_METADATA_LEN, RELAY, }, errors::Errors, utils::{ @@ -301,10 +301,6 @@ where bytes_sent_sec += read_size; self.current_pos += read_size as u64; - // It's fine to include the 0's if there are any in `buffer` (only happens on the last - // chunk of the file). - file.update_checksum(&buffer); - self.writer.write_ext(&mut buffer)?; // Progress bar @@ -323,11 +319,13 @@ where } } - println!(); - debug!("Reached EOF"); + debug!("\nReached EOF"); + + debug!("Computing checksum ..."); + file.compute_checksum()?; + debug!("Ending file transfer and writing checksum"); - buffer[..SIGNAL_LEN].copy_from_slice(Signals::EndFt.as_bytes()); - buffer[SIGNAL_LEN..MAX_CHECKSUM_LEN + SIGNAL_LEN].copy_from_slice(&file.checksum()); + buffer[..MAX_CHECKSUM_LEN].copy_from_slice(&file.checksum()); self.writer.write_ext(&mut buffer)?; info!("Finished successfully"); diff --git a/aft/src/utils.rs b/aft/src/utils.rs index e055b4b..8f13fe2 100644 --- a/aft/src/utils.rs +++ b/aft/src/utils.rs @@ -188,6 +188,20 @@ impl FileOperations { self.hasher.clone().finalize().to_vec() } + /// Computes the checksum of the current file content. Note this will reset the cursor. + pub fn compute_checksum(&mut self) -> io::Result<()> { + let mut buffer = [0u8; 1024]; + + self.reset_checksum(); + self.seek_start(0)?; + + while self.file.read(&mut buffer)? != 0 { + self.update_checksum(&buffer); + } + + Ok(()) + } + /// Updates the checksum. pub fn update_checksum(&mut self, buffer: &[u8]) { self.hasher.update(buffer);