diff --git a/Cargo.lock b/Cargo.lock index 0e7b076..e747387 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,6 +45,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + [[package]] name = "aligned" version = "0.4.3" @@ -1776,6 +1785,15 @@ dependencies = [ "libc", ] +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -2021,6 +2039,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -2702,6 +2729,23 @@ dependencies = [ "bitflags 2.11.1", ] +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + [[package]] name = "remove_background" version = "0.1.0" @@ -2714,6 +2758,8 @@ dependencies = [ "reqwest", "sha2", "show-image", + "tracing", + "tracing-subscriber", ] [[package]] @@ -3029,6 +3075,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -3303,6 +3358,15 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "tiff" version = "0.11.3" @@ -3473,9 +3537,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "tracing-core" version = "0.1.36" @@ -3483,6 +3559,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -3598,6 +3704,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index 3985b94..23e626b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,5 @@ reqwest = { version = "0.13", features = ["blocking"] } sha2 = "0.11" show-image = { version = "0.14", features = ["image"] } + tracing = "0.1" + tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } diff --git a/src/main.rs b/src/main.rs index 6de7e6d..beb91fd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,8 @@ use { io::{Cursor, Read, Write, stdin, stdout}, path::PathBuf, }, + tracing::{debug, info}, + tracing_subscriber::{EnvFilter, fmt}, }; #[derive(Subcommand)] @@ -67,8 +69,20 @@ struct Args { #[show_image::main] fn main() -> Result<(), Box> { + fmt() + .with_env_filter(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"))) + .with_writer(std::io::stderr) + .init(); + let args = Args::parse(); + info!( + model = ?args.model, + custom_model = args.model_path.as_deref().unwrap_or(""), + offline = args.offline, + "Starting remove_background" + ); + let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?; match args.command { @@ -78,8 +92,10 @@ fn main() -> Result<(), Box> { output_file, } => { let img = if let Some(input_file) = input_file { + debug!(path = %input_file.display(), "Reading input image from file"); ImageReader::open(input_file)?.decode()? } else { + debug!("Reading input image from stdin"); let mut bytes = Vec::new(); if stdin().lock().read_to_end(&mut bytes)? == 0 { return Err("No input data provided via stdin".into()); @@ -90,15 +106,9 @@ fn main() -> Result<(), Box> { }; let (img_width, img_height) = img.dimensions(); - println!("Loaded image: {}x{}\n", img_width, img_height); + info!(width = img_width, height = img_height, "Loaded image"); - println!("Running inference..."); let mask = session.predict(&img)?; - println!( - "Inference complete! Mask dimensions: {}x{}", - mask.dimensions().0, - mask.dimensions().1 - ); let result_rgba = apply_mask(&img, &mask)?; if debug { @@ -106,8 +116,10 @@ fn main() -> Result<(), Box> { } if let Some(output_file) = output_file { - result_rgba.save(output_file)?; + result_rgba.save(&output_file)?; + info!(path = %output_file.display(), "Wrote output image to file"); } else { + debug!("Writing output image to stdout (PNG)"); let mut buffer = Cursor::new(Vec::new()); result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?; let mut stdout = stdout().lock(); @@ -119,20 +131,49 @@ fn main() -> Result<(), Box> { output_directory, } => { fs::create_dir_all(&output_directory)?; - let files = fs::read_dir(input_directory)?; - for file in files { - let file = file?; - let path = file.path(); - let img = ImageReader::open(&path)?.decode()?; + info!( + input = %input_directory.display(), + output = %output_directory.display(), + "Starting batch processing" + ); + + let mut processed: usize = 0; + let mut failed: usize = 0; + for entry in fs::read_dir(&input_directory)? { + let entry = entry?; + let path = entry.path(); + let span = tracing::info_span!("batch_item", path = %path.display()); + let _enter = span.enter(); + + let img = match ImageReader::open(&path) { + Ok(reader) => match reader.decode() { + Ok(img) => img, + Err(e) => { + tracing::warn!(error = %e, "Failed to decode image, skipping"); + failed += 1; + continue; + } + }, + Err(e) => { + tracing::warn!(error = %e, "Failed to open image, skipping"); + failed += 1; + continue; + } + }; + let mask = session.predict(&img)?; let result_rgba = apply_mask(&img, &mask)?; let output_path = output_directory.join(path.file_name().unwrap()); result_rgba.save(&output_path)?; - println!("Processed image saved to: {}", output_path.display()); + processed += 1; + info!(output = %output_path.display(), "Processed image saved"); } + + info!(processed, failed, "Batch processing complete"); } } + info!("Shutting down"); Ok(()) } @@ -140,7 +181,7 @@ fn debug_mode( img: &image::DynamicImage, result_rgba: &image::RgbaImage, ) -> Result<(), Box> { - println!("Creating side-by-side comparison..."); + info!("Creating side-by-side comparison"); let composite = create_side_by_side(img, result_rgba)?; let composite_dynamic = image::DynamicImage::ImageRgba8(composite); @@ -156,21 +197,18 @@ fn debug_mode( .map_err(|e| e.to_string())?, )?; - println!("\n=== Done! ==="); - println!( - "Displaying side-by-side comparison ({}x{}):", - comp_width, comp_height + info!( + width = comp_width, + height = comp_height, + "Displaying side-by-side comparison (left: original, right: background removed on checkered bg); press ESC to close" ); - println!(" Left: Original image"); - println!(" Right: Background removed (shown on checkered background)"); - println!("\nPress ESC to close the window."); for event in window.event_channel()? { if let event::WindowEvent::KeyboardInput(event) = event && event.input.key_code == Some(event::VirtualKeyCode::Escape) && event.input.state.is_pressed() { - println!("ESC pressed, closing..."); + debug!("ESC pressed, closing window"); break; } } diff --git a/src/model.rs b/src/model.rs index aa05679..74abef5 100644 --- a/src/model.rs +++ b/src/model.rs @@ -8,9 +8,11 @@ use { std::{ error::Error, fs, - io::{Read, Write}, + io::Read, path::{Path, PathBuf}, + time::Instant, }, + tracing::{debug, info, warn}, }; /// CLI-facing model selector. Concrete session metadata (URL, checksum, @@ -34,7 +36,8 @@ fn get_cache_dir() -> Result> { } fn download_file(url: &str, dest: &Path) -> Result<(), Box> { - println!("Downloading model from {}...", url); + info!(%url, dest = %dest.display(), "Downloading model"); + let start = Instant::now(); let client = reqwest::blocking::Client::builder() .user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36") @@ -46,9 +49,10 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box> { return Err(format!("Failed to download model: HTTP {}", response.status()).into()); } - let mut file = fs::File::create(dest)?; + let mut file = std::io::BufWriter::new(fs::File::create(dest)?); let total_size = response.content_length().unwrap_or(0); let mut downloaded = 0u64; + let mut last_reported_pct: u32 = 0; let mut buffer = vec![0; 8192]; loop { @@ -56,21 +60,24 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box> { if bytes_read == 0 { break; } - file.write_all(&buffer[..bytes_read])?; + std::io::Write::write_all(&mut file, &buffer[..bytes_read])?; downloaded += bytes_read as u64; if total_size > 0 { let progress = (downloaded as f64 / total_size as f64 * 100.0) as u32; - print!("\rDownloading... {}%", progress); - std::io::stdout().flush()?; + // Only emit every 10% to avoid log spam while still giving feedback. + if progress >= last_reported_pct + 10 { + last_reported_pct = progress - (progress % 10); + debug!(percent = last_reported_pct, downloaded, total = total_size, "Download progress"); + } } } - if total_size > 0 { - println!("\rDownload complete! "); - } else { - println!("Download complete! ({} bytes)", downloaded); - } + info!( + bytes = downloaded, + elapsed_ms = start.elapsed().as_millis() as u64, + "Download complete" + ); Ok(()) } @@ -110,7 +117,7 @@ pub fn get_model_path( if !model_path.exists() { return Err(format!("Custom model path does not exist: {}", path).into()); } - println!("Using custom model: {}", path); + info!(%path, "Using custom model"); return Ok(model_path); } @@ -119,15 +126,14 @@ pub fn get_model_path( let model_path = cache_dir.join(&model_filename); if model_path.exists() { - println!("Using cached model: {}", model_path.display()); + info!(path = %model_path.display(), "Using cached model"); if let Some(expected_hash) = sha256 { - print!("Verifying model integrity... "); - std::io::stdout().flush()?; + debug!("Verifying cached model integrity"); if verify_hash(&model_path, expected_hash)? { - println!("OK"); + debug!("Cached model hash OK"); } else { - println!("FAILED"); + warn!(path = %model_path.display(), "Cached model hash verification failed"); return Err( "Model hash verification failed. Try deleting the cached model.".into(), ); @@ -145,14 +151,13 @@ pub fn get_model_path( .into()); } - println!("Model not found in cache, downloading..."); + info!("Model not found in cache, downloading"); download_file(url, &model_path)?; if let Some(expected_hash) = sha256 { - print!("Verifying downloaded model... "); - std::io::stdout().flush()?; + debug!("Verifying downloaded model"); if verify_hash(&model_path, expected_hash)? { - println!("OK"); + debug!("Downloaded model hash OK"); } else { fs::remove_file(&model_path)?; return Err("Downloaded model hash verification failed".into()); @@ -164,7 +169,8 @@ pub fn get_model_path( /// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU). pub fn create_session(model_path: &Path) -> Result> { - println!("Loading model into ONNX Runtime with CUDA backend..."); + info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend"); + let start = Instant::now(); let session = Session::builder()? .with_execution_providers([CUDAExecutionProvider::default().build()])? @@ -172,7 +178,10 @@ pub fn create_session(model_path: &Path) -> Result> { .with_intra_threads(4)? .commit_from_file(model_path)?; - println!("Model loaded successfully with CUDA!"); + info!( + elapsed_ms = start.elapsed().as_millis() as u64, + "Model loaded successfully" + ); Ok(session) } diff --git a/src/postprocessing.rs b/src/postprocessing.rs index 1d9c408..98b646a 100644 --- a/src/postprocessing.rs +++ b/src/postprocessing.rs @@ -1,6 +1,7 @@ use { image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType}, std::error::Error, + tracing::debug, }; /// Compose `original` with `mask` as the alpha channel and return an RGBA image. @@ -8,17 +9,18 @@ use { /// The mask is expected to already be grayscale. If its dimensions differ from /// the original, it is resized with LANCZOS3. pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result> { - println!("Applying mask to remove background..."); - let (orig_width, orig_height) = original.dimensions(); let (mask_width, mask_height) = mask.dimensions(); - println!("Mask dimensions: {}x{}", mask_width, mask_height); - println!("Original dimensions: {}x{}", orig_width, orig_height); - let resized_mask: std::borrow::Cow<'_, GrayImage> = if mask_width != orig_width || mask_height != orig_height { - println!("Resizing mask to match original image..."); + debug!( + mask_width, + mask_height, + orig_width, + orig_height, + "Resizing mask to match original image" + ); std::borrow::Cow::Owned(image::imageops::resize( mask, orig_width, @@ -45,7 +47,7 @@ pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result Result> { let (orig_w, orig_h) = img.dimensions(); + + let preprocess_start = Instant::now(); let input = self.normalize(img)?; + debug!(elapsed_ms = preprocess_start.elapsed().as_millis() as u64, "Preprocessing complete"); let apply_sigmoid = self.apply_sigmoid(); let input_name = self.input_name().to_string(); + debug!(%input_name, "Running ONNX session"); + let infer_start = Instant::now(); let outputs = self .inner() .run(ort::inputs![input_name => Tensor::from_array(input)?])?; + info!( + elapsed_ms = infer_start.elapsed().as_millis() as u64, + "ONNX inference complete" + ); let output = &outputs[0]; let (shape, data) = output.try_extract_tensor::()?; @@ -177,6 +187,11 @@ pub trait Session { } let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3); + debug!( + mask_width = orig_w, + mask_height = orig_h, + "Mask resized to original image dimensions" + ); Ok(mask) } } @@ -191,7 +206,7 @@ pub fn init_session( if !path.exists() { return Err(format!("Custom model path does not exist: {}", custom_path).into()); } - println!("Using custom model: {}", custom_path); + info!(path = %custom_path, "Using custom model"); match model { Model::Bria => Box::new(BriaSession::from_model_path(&path)?), Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?), @@ -199,11 +214,11 @@ pub fn init_session( } else { match model { Model::Bria => { - println!("Using model: bria-rmbg"); + info!(model = "bria-rmbg", "Using model"); Box::new(BriaSession::new(offline)?) } Model::BiRefNetLite => { - println!("Using model: birefnet-general-lite"); + info!(model = "birefnet-general-lite", "Using model"); Box::new(BiRefNetLiteSession::new(offline)?) } }