diff --git a/src/compression_engine/program.rs b/src/compression_engine/program.rs index 08107f6..64de7d6 100644 --- a/src/compression_engine/program.rs +++ b/src/compression_engine/program.rs @@ -6,10 +6,53 @@ use std::fs::File; use std::io::{Read, Write}; use std::os::unix::fs::PermissionsExt; use std::path::PathBuf; -use std::process::{Command, Stdio}; +use std::process::{Child, Command, Stdio}; +use std::sync::Arc; use crate::compression_engine::CompressionEngine; +pub struct ProgramReader { + process: Child, + stdout: Option, +} + +impl Read for ProgramReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.stdout.as_mut().unwrap().read(buf) + } +} + +impl Drop for ProgramReader { + fn drop(&mut self) { + // Ensure the process is waited on to prevent zombie processes + let _ = self.process.wait(); + } +} + +pub struct ProgramWriter { + process: Child, + stdin: Option, +} + +impl Write for ProgramWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.stdin.as_mut().unwrap().write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.stdin.as_mut().unwrap().flush() + } +} + +impl Drop for ProgramWriter { + fn drop(&mut self) { + // Close stdin to signal EOF to the child process + drop(self.stdin.take()); + // Ensure the process is waited on to prevent zombie processes + let _ = self.process.wait(); + } +} + #[derive(Debug, Eq, PartialEq, Clone)] pub struct CompressionEngineProgram { pub program: String, @@ -72,7 +115,7 @@ impl CompressionEngine for CompressionEngineProgram { let file = File::open(file_path).context("Unable to open file for reading")?; - let process = Command::new(program.clone()) + let mut process = Command::new(program.clone()) .args(args.clone()) .stdin(file) .stdout(Stdio::piped()) @@ -82,11 +125,19 @@ impl CompressionEngine for CompressionEngineProgram { program, args ))?; - Ok(Box::new(process.stdout.unwrap())) + + let stdout = process.stdout.take().ok_or_else(|| { + anyhow!("Failed to capture stdout from child process") + })?; + + Ok(Box::new(ProgramReader { + process, + stdout: Some(stdout), + })) } fn create(&self, file_path: PathBuf) -> Result> { - debug!("COMPRESSION: Writting to {:?} using {:?}", file_path, *self); + debug!("COMPRESSION: Writing to {:?} using {:?}", file_path, *self); let program = self.program.clone(); let args = self.compress.clone(); @@ -98,7 +149,7 @@ impl CompressionEngine for CompressionEngineProgram { let file = File::create(file_path).context("Unable to open file for writing")?; - let process = Command::new(program.clone()) + let mut process = Command::new(program.clone()) .args(args.clone()) .stdin(Stdio::piped()) .stdout(file) @@ -109,6 +160,13 @@ impl CompressionEngine for CompressionEngineProgram { args ))?; - Ok(Box::new(process.stdin.unwrap())) + let stdin = process.stdin.take().ok_or_else(|| { + anyhow!("Failed to capture stdin from child process") + })?; + + Ok(Box::new(ProgramWriter { + process, + stdin: Some(stdin), + })) } } diff --git a/src/modes/delete.rs b/src/modes/delete.rs index feadc3c..f87bf50 100644 --- a/src/modes/delete.rs +++ b/src/modes/delete.rs @@ -35,6 +35,11 @@ pub fn mode_delete( debug!("MAIN: Found item {:?}", item); db::delete_item(conn, item)?; + // Validate that item ID is positive to prevent path traversal issues + if *item_id <= 0 { + return Err(anyhow!("Invalid item ID: {}", item_id)); + } + let mut item_path = data_path.clone(); item_path.push(item_id.to_string()); diff --git a/src/modes/diff.rs b/src/modes/diff.rs index ead2ee9..ee0b1df 100644 --- a/src/modes/diff.rs +++ b/src/modes/diff.rs @@ -11,6 +11,24 @@ use nix::unistd::{close, pipe}; use std::io::Read; use std::os::fd::FromRawFd; use std::process::Stdio; +use std::sync::{Arc, Mutex}; + +// RAII guard for file descriptors to ensure they're closed +struct FdGuard { + fd: c_int, +} + +impl FdGuard { + fn new(fd: c_int) -> Self { + Self { fd } + } +} + +impl Drop for FdGuard { + fn drop(&mut self) { + let _ = close(self.fd); + } +} pub fn mode_diff( cmd: &mut Command, @@ -44,6 +62,14 @@ pub fn mode_diff( log::debug!("MAIN: Found item A {:?}", item_a); log::debug!("MAIN: Found item B {:?}", item_b); + let item_a_id = item_a.id.ok_or_else(|| anyhow!("Item A missing ID"))?; + let item_b_id = item_b.id.ok_or_else(|| anyhow!("Item B missing ID"))?; + + // Validate that item IDs are positive to prevent path traversal issues + if item_a_id <= 0 || item_b_id <= 0 { + return Err(anyhow!("Invalid item ID: {} or {}", item_a_id, item_b_id)); + } + let item_a_tags: Vec = crate::db::get_item_tags(conn, &item_a)? .into_iter() .map(|x| x.name) @@ -55,12 +81,12 @@ pub fn mode_diff( .collect(); let mut item_path_a = data_path.clone(); - item_path_a.push(item_a.id.unwrap().to_string()); // id.unwrap() is safe due to ok_or_else + item_path_a.push(item_a_id.to_string()); let compression_type_a = CompressionType::from_str(&item_a.compression)?; log::debug!("MAIN: Item A has compression type {:?}", compression_type_a); let mut item_path_b = data_path.clone(); - item_path_b.push(item_b.id.unwrap().to_string()); + item_path_b.push(item_b_id.to_string()); let compression_type_b = CompressionType::from_str(&item_b.compression)?; log::debug!("MAIN: Item B has compression type {:?}", compression_type_b); @@ -70,10 +96,11 @@ pub fn mode_diff( let (fd_b_read, fd_b_write) = pipe().map_err(|e: NixError| anyhow!("Failed to create pipe B: {}", e))?; - // Set FD_CLOEXEC on write ends. While they are consumed by File::from_raw_fd, - // it's good practice if the raw FDs were to be handled further before that. - // For this specific code, since from_raw_fd takes ownership immediately, this is less critical - // but doesn't hurt. + // Wrap file descriptors in RAII guards + let _fd_a_read_guard = FdGuard::new(fd_a_read); + let _fd_b_read_guard = FdGuard::new(fd_b_read); + + // Set FD_CLOEXEC on write ends nix::fcntl::fcntl( fd_a_write, nix::fcntl::FcntlArg::F_SETFD(FdFlag::FD_CLOEXEC), @@ -92,14 +119,14 @@ pub fn mode_diff( .arg("--label") .arg(format!( "Keep item A: {} {}", - item_a.id.unwrap(), + item_a_id, item_a_tags.join(" ") )) .arg(format!("/dev/fd/{}", fd_a_read)) .arg("--label") .arg(format!( "Keep item B: {} {}", - item_b.id.unwrap(), + item_b_id, item_b_tags.join(" ") )) .arg(format!("/dev/fd/{}", fd_b_read)) @@ -111,8 +138,9 @@ pub fn mode_diff( .spawn() .map_err(|e| anyhow!("Failed to execute diff command: {}", e))?; - close(fd_a_read).map_err(|e| anyhow!("Failed to close fd_a_read in parent: {}", e))?; - close(fd_b_read).map_err(|e| anyhow!("Failed to close fd_b_read in parent: {}", e))?; + // Close read ends in parent process - they're now guarded by FdGuard + drop(_fd_a_read_guard); + drop(_fd_b_read_guard); let mut child_stdout_pipe = child_process .stdout @@ -130,7 +158,7 @@ pub fn mode_diff( item_path: PathBuf, compression_type: CompressionType, pipe_writer_raw: std::fs::File, - ) { + ) -> Result<()> { use std::io::BufWriter; let mut buffered_pipe_writer = BufWriter::new(pipe_writer_raw); let engine = @@ -138,8 +166,9 @@ pub fn mode_diff( log::debug!("THREAD: Sending item to diff"); engine .copy(item_path, &mut buffered_pipe_writer) - .expect("Failed to copy/compress item"); + .map_err(|e| anyhow!("Failed to copy/compress item: {}", e))?; log::debug!("THREAD: Done sending item to diff"); + Ok(()) } // Function to spawn a writer thread for an item @@ -147,10 +176,10 @@ pub fn mode_diff( item_path: PathBuf, compression_type: CompressionType, fd_write: c_int, - ) -> std::thread::JoinHandle<()> { + ) -> std::thread::JoinHandle> { let pipe_writer_raw = unsafe { std::fs::File::from_raw_fd(fd_write) }; std::thread::spawn(move || { - write_item_to_pipe(item_path, compression_type, pipe_writer_raw); + write_item_to_pipe(item_path, compression_type, pipe_writer_raw) }) } @@ -184,25 +213,39 @@ pub fn mode_diff( // Wait for writer threads to complete (meaning all input has been sent to diff) log::debug!("MAIN: Waiting on writer thread for item A"); - if let Err(panic_payload) = writer_thread_a.join() { - // Propagate panic from writer thread - return Err(anyhow!( - "Writer thread for item A (ID: {}) panicked: {:?}", - ids[0], - panic_payload - )); + match writer_thread_a.join() { + Ok(Ok(())) => { + log::debug!("MAIN: Writer thread for item A completed successfully."); + } + Ok(Err(e)) => { + return Err(anyhow!("Writer thread for item A failed: {}", e)); + } + Err(panic_payload) => { + return Err(anyhow!( + "Writer thread for item A (ID: {}) panicked: {:?}", + ids[0], + panic_payload + )); + } } - log::debug!("MAIN: Writer thread for item A completed."); log::debug!("MAIN: Waiting on writer thread for item B"); - if let Err(panic_payload) = writer_thread_b.join() { - return Err(anyhow!( - "Writer thread for item B (ID: {}) panicked: {:?}", - ids[1], - panic_payload - )); + match writer_thread_b.join() { + Ok(Ok(())) => { + log::debug!("MAIN: Writer thread for item B completed successfully."); + } + Ok(Err(e)) => { + return Err(anyhow!("Writer thread for item B failed: {}", e)); + } + Err(panic_payload) => { + return Err(anyhow!( + "Writer thread for item B (ID: {}) panicked: {:?}", + ids[1], + panic_payload + )); + } } - log::debug!("MAIN: Writer thread for item B completed."); + log::debug!("MAIN: Done waiting on input-writer threads."); // Now that all input has been sent and input pipes will be closed by threads exiting, @@ -217,24 +260,19 @@ pub fn mode_diff( ); // Retrieve the captured output from the reader threads. - // .join().unwrap() here will panic if the reader thread itself panicked. - // The inner Result is from the read_to_end operation within the thread. let stdout_capture_result = stdout_reader_thread .join() - .unwrap_or_else(|panic_payload| { - Err(anyhow!( - "Stdout reader thread panicked: {:?}", - panic_payload - )) - })?; + .map_err(|panic_payload| { + anyhow!("Stdout reader thread panicked: {:?}", panic_payload) + })? + .map_err(|e| anyhow!("Failed to read diff stdout: {}", e))?; + let stderr_capture_result = stderr_reader_thread .join() - .unwrap_or_else(|panic_payload| { - Err(anyhow!( - "Stderr reader thread panicked: {:?}", - panic_payload - )) - })?; + .map_err(|panic_payload| { + anyhow!("Stderr reader thread panicked: {:?}", panic_payload) + })? + .map_err(|e| anyhow!("Failed to read diff stderr: {}", e))?; // Handle diff's exit status and output match diff_status.code() { diff --git a/src/modes/get.rs b/src/modes/get.rs index f399a12..c7d6297 100644 --- a/src/modes/get.rs +++ b/src/modes/get.rs @@ -34,8 +34,14 @@ pub fn mode_get( }; if let Some(item) = item_maybe { + let item_id = item.id.ok_or_else(|| anyhow!("Item missing ID"))?; + // Validate that item ID is positive to prevent path traversal issues + if item_id <= 0 { + return Err(anyhow!("Invalid item ID: {}", item_id)); + } + let mut item_path = data_path.clone(); - item_path.push(item.id.unwrap().to_string()); + item_path.push(item_id.to_string()); let compression_type = CompressionType::from_str(&item.compression)?; let compression_engine = get_compression_engine(compression_type)?; diff --git a/src/modes/save.rs b/src/modes/save.rs index 6ae6ab6..f7ee068 100644 --- a/src/modes/save.rs +++ b/src/modes/save.rs @@ -141,8 +141,9 @@ pub fn mode_save( db::store_meta(conn, meta)?; } + let item_id = item.id.ok_or_else(|| anyhow!("Item missing ID"))?; let mut item_path = data_path.clone(); - item_path.push(id.to_string()); + item_path.push(item_id.to_string()); let mut stdin = io::stdin().lock(); let mut stdout = io::stdout().lock(); diff --git a/src/modes/update.rs b/src/modes/update.rs index 48d263e..bea8ebd 100644 --- a/src/modes/update.rs +++ b/src/modes/update.rs @@ -40,9 +40,10 @@ pub fn mode_update( db::set_item_tags(conn, item.clone(), tags)?; } + let item_id = item.id.ok_or_else(|| anyhow!("Item missing ID"))?; let item_path = { let mut path = data_path.clone(); - path.push(item.id.unwrap().to_string()); + path.push(item_id.to_string()); path };