diff --git a/src/main.rs b/src/main.rs index 0530f46..ea8aed4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ use clap::Parser; use image::GenericImageView; +use ndarray::IntoDimension; use ort::value::Tensor; use show_image::{AsImageView, create_window, event}; use std::error::Error; @@ -73,7 +74,18 @@ fn main() -> Result<(), Box> { let mask_output = &outputs[0]; let (mask_shape, mask_array) = mask_output.try_extract_tensor::()?; // Returns (shape, &[f32]) - /*let mask_tensor = mask_array.into_dimensionality::()?.to_owned(); + // Convert the slice to Array4 + let mask_tensor = ndarray::ArrayView::from_shape( + ( + mask_shape[0] as usize, + mask_shape[1] as usize, + mask_shape[2] as usize, + mask_shape[3] as usize, + ) + .into_dimension(), + mask_array, + )? + .to_owned(); println!( "Inference complete! Output shape: {:?}", @@ -108,11 +120,7 @@ fn main() -> Result<(), Box> { ); println!(" Left: Original image"); println!(" Right: Background removed (shown on checkered background)"); - println!("\nPress ESC to close the window."); */ - let window = create_window( - "Background Removal - Original (Left) vs Result (Right) - Press ESC to close", - Default::default(), - )?; + println!("\nPress ESC to close the window."); // Event loop - wait for ESC key to close for event in window.event_channel()? { if let event::WindowEvent::KeyboardInput(event) = event { diff --git a/src/model.rs b/src/model.rs index a922313..43777fa 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,5 +1,5 @@ use ort::session::{Session, builder::GraphOptimizationLevel}; -use sha2::{Sha256, Digest}; +use sha2::{Digest, Sha256}; use std::error::Error; use std::fs; use std::io::Write; @@ -17,7 +17,7 @@ impl ModelInfo { /// BiRefNet General Lite model pub const BIREFNET_LITE: ModelInfo = ModelInfo { name: "birefnet-general-lite", - url: "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet-general-lite.onnx", + url: "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png", sha256: None, // We'll skip verification for now input_size: (1024, 1024), }; @@ -25,9 +25,11 @@ impl ModelInfo { /// Get the cache directory for models fn get_cache_dir() -> Result> { - let home = std::env::var("HOME") - .or_else(|_| std::env::var("USERPROFILE"))?; - let cache_dir = Path::new(&home).join(".cache").join("remove_background").join("models"); + let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?; + let cache_dir = Path::new(&home) + .join(".cache") + .join("remove_background") + .join("models"); fs::create_dir_all(&cache_dir)?; Ok(cache_dir) } @@ -35,21 +37,21 @@ fn get_cache_dir() -> Result> { /// Download a file from URL to destination fn download_file(url: &str, dest: &Path) -> Result<(), Box> { println!("Downloading model from {}...", url); - + let client = reqwest::blocking::Client::builder() .user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36") .build()?; - + let mut response = client.get(url).send()?; - + if !response.status().is_success() { return Err(format!("Failed to download model: HTTP {}", response.status()).into()); } - + let mut file = fs::File::create(dest)?; let total_size = response.content_length().unwrap_or(0); let mut downloaded = 0u64; - + let mut buffer = vec![0; 8192]; loop { let bytes_read = std::io::Read::read(&mut response, &mut buffer)?; @@ -58,20 +60,20 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box> { } file.write_all(&buffer[..bytes_read])?; downloaded += bytes_read as u64; - + if total_size > 0 { let progress = (downloaded as f64 / total_size as f64 * 100.0) as u32; print!("\rDownloading... {}%", progress); std::io::stdout().flush()?; } } - + if total_size > 0 { println!("\rDownload complete! "); } else { println!("Download complete! ({} bytes)", downloaded); } - + Ok(()) } @@ -86,7 +88,11 @@ fn verify_hash(file_path: &Path, expected_hash: &str) -> Result, offline: bool) -> Result> { +pub fn get_model_path( + model_info: &ModelInfo, + custom_path: Option<&str>, + offline: bool, +) -> Result> { // If custom path provided, use it if let Some(path) = custom_path { let model_path = PathBuf::from(path); @@ -96,15 +102,15 @@ pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline println!("Using custom model: {}", path); return Ok(model_path); } - + // Check cache let cache_dir = get_cache_dir()?; let model_filename = format!("{}.onnx", model_info.name); let model_path = cache_dir.join(&model_filename); - + if model_path.exists() { println!("Using cached model: {}", model_path.display()); - + // Verify hash if provided if let Some(expected_hash) = model_info.sha256 { print!("Verifying model integrity... "); @@ -113,24 +119,27 @@ pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline println!("OK"); } else { println!("FAILED"); - return Err("Model hash verification failed. Try deleting the cached model.".into()); + return Err( + "Model hash verification failed. Try deleting the cached model.".into(), + ); } } - + return Ok(model_path); } - + // Download if not in offline mode if offline { return Err(format!( "Model not found in cache and offline mode is enabled. Cache path: {}", model_path.display() - ).into()); + ) + .into()); } - + println!("Model not found in cache, downloading..."); download_file(model_info.url, &model_path)?; - + // Verify after download if let Some(expected_hash) = model_info.sha256 { print!("Verifying downloaded model... "); @@ -143,20 +152,20 @@ pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline return Err("Downloaded model hash verification failed".into()); } } - + Ok(model_path) } /// Create an ONNX Runtime session from model path pub fn create_session(model_path: &Path) -> Result> { println!("Loading model into ONNX Runtime..."); - + let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(4)? .commit_from_file(model_path)?; - + println!("Model loaded successfully!"); - + Ok(session) } diff --git a/src/postprocessing.rs b/src/postprocessing.rs index ed7b11e..8865a3f 100644 --- a/src/postprocessing.rs +++ b/src/postprocessing.rs @@ -1,9 +1,9 @@ -use image::{DynamicImage, RgbaImage, Rgba, GenericImageView, imageops::FilterType}; +use image::{DynamicImage, GenericImageView, Rgba, RgbaImage, imageops::FilterType}; use ndarray::Array4; use std::error::Error; /// Apply mask to original image to remove background -/// +/// /// Steps: /// 1. Extract mask from output tensor (shape: [1, 1, H, W]) /// 2. Resize mask to match original image dimensions @@ -13,24 +13,21 @@ pub fn apply_mask( mask_tensor: Array4, ) -> Result> { println!("Applying mask to remove background..."); - + let (orig_width, orig_height) = original.dimensions(); - + // Extract mask dimensions let mask_shape = mask_tensor.shape(); if mask_shape[0] != 1 || mask_shape[1] != 1 { - return Err(format!( - "Expected mask shape [1, 1, H, W], got {:?}", - mask_shape - ).into()); + return Err(format!("Expected mask shape [1, 1, H, W], got {:?}", mask_shape).into()); } - + let mask_height = mask_shape[2] as u32; let mask_width = mask_shape[3] as u32; - + println!("Mask dimensions: {}x{}", mask_width, mask_height); println!("Original dimensions: {}x{}", orig_width, orig_height); - + // Create a grayscale image from the mask let mut mask_image = image::GrayImage::new(mask_width, mask_height); for y in 0..mask_height { @@ -41,29 +38,24 @@ pub fn apply_mask( mask_image.put_pixel(x, y, image::Luma([pixel_value])); } } - + // Resize mask to match original image dimensions if needed let resized_mask = if mask_width != orig_width || mask_height != orig_height { println!("Resizing mask to match original image..."); - image::imageops::resize( - &mask_image, - orig_width, - orig_height, - FilterType::Lanczos3, - ) + image::imageops::resize(&mask_image, orig_width, orig_height, FilterType::Lanczos3) } else { mask_image }; - + // Convert original to RGBA and apply mask let rgba_original = original.to_rgba8(); let mut result = RgbaImage::new(orig_width, orig_height); - + for y in 0..orig_height { for x in 0..orig_width { let orig_pixel = rgba_original.get_pixel(x, y); let mask_pixel = resized_mask.get_pixel(x, y); - + // Apply mask as alpha channel result.put_pixel( x, @@ -77,9 +69,9 @@ pub fn apply_mask( ); } } - + println!("Background removal complete!"); - + Ok(result) } @@ -90,7 +82,7 @@ pub fn create_side_by_side( ) -> Result> { let (width, height) = original.dimensions(); let mut composite = RgbaImage::new(width * 2, height); - + // Left side: original image let original_rgba = original.to_rgba8(); for y in 0..height { @@ -98,31 +90,27 @@ pub fn create_side_by_side( composite.put_pixel(x, y, *original_rgba.get_pixel(x, y)); } } - + // Right side: result with checkered background for transparency for y in 0..height { for x in 0..width { let result_pixel = result.get_pixel(x, y); let alpha = result_pixel[3] as f32 / 255.0; - + // Create checkered background (8x8 squares) let checker_size = 8; let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0; let bg_color = if is_light { 200 } else { 150 }; - + // Alpha blend with checkered background let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; - - composite.put_pixel( - x + width, - y, - Rgba([final_r, final_g, final_b, 255]), - ); + + composite.put_pixel(x + width, y, Rgba([final_r, final_g, final_b, 255])); } } - + Ok(composite) } @@ -130,7 +118,7 @@ pub fn create_side_by_side( mod tests { use super::*; use ndarray::Array4; - + #[test] fn test_apply_mask_shape() { let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100)); @@ -138,7 +126,7 @@ mod tests { let result = apply_mask(&img, mask).unwrap(); assert_eq!(result.dimensions(), (100, 100)); } - + #[test] fn test_side_by_side_dimensions() { let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100)); diff --git a/src/preprocessing.rs b/src/preprocessing.rs index 95d514c..3f7e1cd 100644 --- a/src/preprocessing.rs +++ b/src/preprocessing.rs @@ -3,7 +3,7 @@ use ndarray::Array4; use std::error::Error; /// Preprocess an image for the BiRefNet model -/// +/// /// Steps: /// 1. Resize to target dimensions (1024x1024) /// 2. Convert from u8 [0, 255] to f32 [0.0, 1.0] @@ -15,30 +15,30 @@ pub fn preprocess_image( target_height: u32, ) -> Result, Box> { println!("Preprocessing image..."); - + // Step 1: Resize image let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3); let rgb_image = resized.to_rgb8(); - + let (width, height) = rgb_image.dimensions(); - + // Step 2: Create ndarray with shape [1, 3, height, width] let mut array = Array4::::zeros((1, 3, height as usize, width as usize)); - + // Step 3: Fill the array, converting from HWC to CHW and normalizing for y in 0..height { for x in 0..width { let pixel = rgb_image.get_pixel(x, y); - + // Normalize from [0, 255] to [0.0, 1.0] array[[0, 0, y as usize, x as usize]] = pixel[0] as f32 / 255.0; // R array[[0, 1, y as usize, x as usize]] = pixel[1] as f32 / 255.0; // G array[[0, 2, y as usize, x as usize]] = pixel[2] as f32 / 255.0; // B } } - + println!("Preprocessing complete. Tensor shape: {:?}", array.shape()); - + Ok(array) } @@ -46,19 +46,19 @@ pub fn preprocess_image( mod tests { use super::*; use image::RgbImage; - + #[test] fn test_preprocess_shape() { let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100)); let result = preprocess_image(&img, 1024, 1024).unwrap(); assert_eq!(result.shape(), &[1, 3, 1024, 1024]); } - + #[test] fn test_preprocess_normalization() { let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100)); let result = preprocess_image(&img, 1024, 1024).unwrap(); - + // Check that all values are in [0.0, 1.0] for &val in result.iter() { assert!(val >= 0.0 && val <= 1.0);