remove_background/src/postprocessing.rs

138 lines
4.6 KiB
Rust
Raw Normal View History

2026-01-21 23:56:33 +00:00
use image::{DynamicImage, GenericImageView, Rgba, RgbaImage, imageops::FilterType};
2026-01-21 23:25:11 +00:00
use ndarray::Array4;
use std::error::Error;
/// Apply mask to original image to remove background
2026-01-21 23:56:33 +00:00
///
2026-01-21 23:25:11 +00:00
/// Steps:
/// 1. Extract mask from output tensor (shape: [1, 1, H, W])
/// 2. Resize mask to match original image dimensions
/// 3. Apply mask as alpha channel to create RGBA image
pub fn apply_mask(
original: &DynamicImage,
mask_tensor: Array4<f32>,
) -> Result<RgbaImage, Box<dyn Error>> {
println!("Applying mask to remove background...");
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
let (orig_width, orig_height) = original.dimensions();
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
// Extract mask dimensions
let mask_shape = mask_tensor.shape();
if mask_shape[0] != 1 || mask_shape[1] != 1 {
2026-01-21 23:56:33 +00:00
return Err(format!("Expected mask shape [1, 1, H, W], got {:?}", mask_shape).into());
2026-01-21 23:25:11 +00:00
}
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
let mask_height = mask_shape[2] as u32;
let mask_width = mask_shape[3] as u32;
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
println!("Mask dimensions: {}x{}", mask_width, mask_height);
println!("Original dimensions: {}x{}", orig_width, orig_height);
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
// Create a grayscale image from the mask
let mut mask_image = image::GrayImage::new(mask_width, mask_height);
for y in 0..mask_height {
for x in 0..mask_width {
let mask_value = mask_tensor[[0, 0, y as usize, x as usize]];
// Clamp and convert to u8
let pixel_value = (mask_value.clamp(0.0, 1.0) * 255.0) as u8;
mask_image.put_pixel(x, y, image::Luma([pixel_value]));
}
}
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
// Resize mask to match original image dimensions if needed
let resized_mask = if mask_width != orig_width || mask_height != orig_height {
println!("Resizing mask to match original image...");
2026-01-21 23:56:33 +00:00
image::imageops::resize(&mask_image, orig_width, orig_height, FilterType::Lanczos3)
2026-01-21 23:25:11 +00:00
} else {
mask_image
};
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
// Convert original to RGBA and apply mask
let rgba_original = original.to_rgba8();
let mut result = RgbaImage::new(orig_width, orig_height);
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
for y in 0..orig_height {
for x in 0..orig_width {
let orig_pixel = rgba_original.get_pixel(x, y);
let mask_pixel = resized_mask.get_pixel(x, y);
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
// Apply mask as alpha channel
result.put_pixel(
x,
y,
Rgba([
orig_pixel[0], // R
orig_pixel[1], // G
orig_pixel[2], // B
mask_pixel[0], // Alpha from mask
]),
);
}
}
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
println!("Background removal complete!");
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
Ok(result)
}
/// Create a side-by-side comparison image
pub fn create_side_by_side(
original: &DynamicImage,
result: &RgbaImage,
) -> Result<RgbaImage, Box<dyn Error>> {
let (width, height) = original.dimensions();
let mut composite = RgbaImage::new(width * 2, height);
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
// Left side: original image
let original_rgba = original.to_rgba8();
for y in 0..height {
for x in 0..width {
composite.put_pixel(x, y, *original_rgba.get_pixel(x, y));
}
}
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
// Right side: result with checkered background for transparency
for y in 0..height {
for x in 0..width {
let result_pixel = result.get_pixel(x, y);
let alpha = result_pixel[3] as f32 / 255.0;
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
// Create checkered background (8x8 squares)
let checker_size = 8;
let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0;
let bg_color = if is_light { 200 } else { 150 };
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
// Alpha blend with checkered background
let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
2026-01-21 23:56:33 +00:00
composite.put_pixel(x + width, y, Rgba([final_r, final_g, final_b, 255]));
2026-01-21 23:25:11 +00:00
}
}
2026-01-21 23:56:33 +00:00
2026-01-21 23:25:11 +00:00
Ok(composite)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array4;

    #[test]
    fn test_apply_mask_shape() {
        let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
        let mask = Array4::<f32>::ones((1, 1, 100, 100));
        let result = apply_mask(&img, mask).unwrap();
        assert_eq!(result.dimensions(), (100, 100));
    }

    #[test]
    fn test_apply_mask_sets_alpha_from_mask() {
        let img = DynamicImage::ImageRgb8(image::RgbImage::from_pixel(
            4,
            4,
            image::Rgb([10, 20, 30]),
        ));
        // All-ones mask: fully opaque, RGB preserved.
        let opaque = apply_mask(&img, Array4::<f32>::ones((1, 1, 4, 4))).unwrap();
        assert!(opaque.pixels().all(|p| p.0 == [10, 20, 30, 255]));
        // All-zeros mask: fully transparent, RGB still preserved.
        let transparent = apply_mask(&img, Array4::<f32>::zeros((1, 1, 4, 4))).unwrap();
        assert!(transparent.pixels().all(|p| p.0 == [10, 20, 30, 0]));
    }

    #[test]
    fn test_apply_mask_resizes_mask() {
        let img = DynamicImage::ImageRgb8(image::RgbImage::new(20, 10));
        let mask = Array4::<f32>::ones((1, 1, 5, 5));
        assert_eq!(apply_mask(&img, mask).unwrap().dimensions(), (20, 10));
    }

    #[test]
    fn test_apply_mask_rejects_bad_shape() {
        let img = DynamicImage::ImageRgb8(image::RgbImage::new(4, 4));
        assert!(apply_mask(&img, Array4::<f32>::ones((2, 1, 4, 4))).is_err());
        assert!(apply_mask(&img, Array4::<f32>::ones((1, 3, 4, 4))).is_err());
    }

    #[test]
    fn test_side_by_side_dimensions() {
        let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
        let result_img = RgbaImage::new(100, 100);
        let composite = create_side_by_side(&img, &result_img).unwrap();
        assert_eq!(composite.dimensions(), (200, 100)); // Double width
    }

    #[test]
    fn test_side_by_side_pixels() {
        let img = DynamicImage::ImageRgb8(image::RgbImage::new(16, 16));

        // Fully transparent result: the right half shows only the checkerboard.
        let transparent = RgbaImage::new(16, 16);
        let composite = create_side_by_side(&img, &transparent).unwrap();
        assert_eq!(composite.get_pixel(0, 0).0, [0, 0, 0, 255]);
        assert_eq!(composite.get_pixel(16, 0).0, [200, 200, 200, 255]);
        assert_eq!(composite.get_pixel(24, 0).0, [150, 150, 150, 255]);
        assert_eq!(composite.get_pixel(24, 8).0, [200, 200, 200, 255]);

        // Fully opaque result: the right half reproduces the result's RGB.
        let opaque = RgbaImage::from_pixel(16, 16, Rgba([1, 2, 3, 255]));
        let composite = create_side_by_side(&img, &opaque).unwrap();
        assert_eq!(composite.get_pixel(20, 5).0, [1, 2, 3, 255]);
    }
}