use image::{DynamicImage, RgbaImage, Rgba, GenericImageView, imageops::FilterType}; use ndarray::Array4; use std::error::Error; /// Apply mask to original image to remove background /// /// Steps: /// 1. Extract mask from output tensor (shape: [1, 1, H, W]) /// 2. Resize mask to match original image dimensions /// 3. Apply mask as alpha channel to create RGBA image pub fn apply_mask( original: &DynamicImage, mask_tensor: Array4, ) -> Result> { println!("Applying mask to remove background..."); let (orig_width, orig_height) = original.dimensions(); // Extract mask dimensions let mask_shape = mask_tensor.shape(); if mask_shape[0] != 1 || mask_shape[1] != 1 { return Err(format!( "Expected mask shape [1, 1, H, W], got {:?}", mask_shape ).into()); } let mask_height = mask_shape[2] as u32; let mask_width = mask_shape[3] as u32; println!("Mask dimensions: {}x{}", mask_width, mask_height); println!("Original dimensions: {}x{}", orig_width, orig_height); // Create a grayscale image from the mask let mut mask_image = image::GrayImage::new(mask_width, mask_height); for y in 0..mask_height { for x in 0..mask_width { let mask_value = mask_tensor[[0, 0, y as usize, x as usize]]; // Clamp and convert to u8 let pixel_value = (mask_value.clamp(0.0, 1.0) * 255.0) as u8; mask_image.put_pixel(x, y, image::Luma([pixel_value])); } } // Resize mask to match original image dimensions if needed let resized_mask = if mask_width != orig_width || mask_height != orig_height { println!("Resizing mask to match original image..."); image::imageops::resize( &mask_image, orig_width, orig_height, FilterType::Lanczos3, ) } else { mask_image }; // Convert original to RGBA and apply mask let rgba_original = original.to_rgba8(); let mut result = RgbaImage::new(orig_width, orig_height); for y in 0..orig_height { for x in 0..orig_width { let orig_pixel = rgba_original.get_pixel(x, y); let mask_pixel = resized_mask.get_pixel(x, y); // Apply mask as alpha channel result.put_pixel( x, y, Rgba([ orig_pixel[0], // R orig_pixel[1], // G orig_pixel[2], // B mask_pixel[0], // Alpha from mask ]), ); } } println!("Background removal complete!"); Ok(result) } /// Create a side-by-side comparison image pub fn create_side_by_side( original: &DynamicImage, result: &RgbaImage, ) -> Result> { let (width, height) = original.dimensions(); let mut composite = RgbaImage::new(width * 2, height); // Left side: original image let original_rgba = original.to_rgba8(); for y in 0..height { for x in 0..width { composite.put_pixel(x, y, *original_rgba.get_pixel(x, y)); } } // Right side: result with checkered background for transparency for y in 0..height { for x in 0..width { let result_pixel = result.get_pixel(x, y); let alpha = result_pixel[3] as f32 / 255.0; // Create checkered background (8x8 squares) let checker_size = 8; let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0; let bg_color = if is_light { 200 } else { 150 }; // Alpha blend with checkered background let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; composite.put_pixel( x + width, y, Rgba([final_r, final_g, final_b, 255]), ); } } Ok(composite) } #[cfg(test)] mod tests { use super::*; use ndarray::Array4; #[test] fn test_apply_mask_shape() { let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100)); let mask = Array4::::ones((1, 1, 100, 100)); let result = apply_mask(&img, mask).unwrap(); assert_eq!(result.dimensions(), (100, 100)); } #[test] fn test_side_by_side_dimensions() { let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100)); let result_img = RgbaImage::new(100, 100); let composite = create_side_by_side(&img, &result_img).unwrap(); assert_eq!(composite.dimensions(), (200, 100)); // Double width } }