This commit is contained in:
Matthew Deville 2026-01-22 00:56:33 +01:00
parent 296a0172f1
commit 86e0dd734d
4 changed files with 85 additions and 80 deletions

View file

@ -1,5 +1,6 @@
use clap::Parser; use clap::Parser;
use image::GenericImageView; use image::GenericImageView;
use ndarray::IntoDimension;
use ort::value::Tensor; use ort::value::Tensor;
use show_image::{AsImageView, create_window, event}; use show_image::{AsImageView, create_window, event};
use std::error::Error; use std::error::Error;
@ -73,7 +74,18 @@ fn main() -> Result<(), Box<dyn Error>> {
let mask_output = &outputs[0]; let mask_output = &outputs[0];
let (mask_shape, mask_array) = mask_output.try_extract_tensor::<f32>()?; // Returns (shape, &[f32]) let (mask_shape, mask_array) = mask_output.try_extract_tensor::<f32>()?; // Returns (shape, &[f32])
/*let mask_tensor = mask_array.into_dimensionality::<ndarray::Ix4>()?.to_owned(); // Convert the slice to Array4
let mask_tensor = ndarray::ArrayView::from_shape(
(
mask_shape[0] as usize,
mask_shape[1] as usize,
mask_shape[2] as usize,
mask_shape[3] as usize,
)
.into_dimension(),
mask_array,
)?
.to_owned();
println!( println!(
"Inference complete! Output shape: {:?}", "Inference complete! Output shape: {:?}",
@ -108,11 +120,7 @@ fn main() -> Result<(), Box<dyn Error>> {
); );
println!(" Left: Original image"); println!(" Left: Original image");
println!(" Right: Background removed (shown on checkered background)"); println!(" Right: Background removed (shown on checkered background)");
println!("\nPress ESC to close the window."); */ println!("\nPress ESC to close the window.");
let window = create_window(
"Background Removal - Original (Left) vs Result (Right) - Press ESC to close",
Default::default(),
)?;
// Event loop - wait for ESC key to close // Event loop - wait for ESC key to close
for event in window.event_channel()? { for event in window.event_channel()? {
if let event::WindowEvent::KeyboardInput(event) = event { if let event::WindowEvent::KeyboardInput(event) = event {

View file

@ -1,5 +1,5 @@
use ort::session::{Session, builder::GraphOptimizationLevel}; use ort::session::{Session, builder::GraphOptimizationLevel};
use sha2::{Sha256, Digest}; use sha2::{Digest, Sha256};
use std::error::Error; use std::error::Error;
use std::fs; use std::fs;
use std::io::Write; use std::io::Write;
@ -17,7 +17,7 @@ impl ModelInfo {
/// BiRefNet General Lite model /// BiRefNet General Lite model
pub const BIREFNET_LITE: ModelInfo = ModelInfo { pub const BIREFNET_LITE: ModelInfo = ModelInfo {
name: "birefnet-general-lite", name: "birefnet-general-lite",
url: "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet-general-lite.onnx", url: "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png",
sha256: None, // We'll skip verification for now sha256: None, // We'll skip verification for now
input_size: (1024, 1024), input_size: (1024, 1024),
}; };
@ -25,9 +25,11 @@ impl ModelInfo {
/// Get the cache directory for models /// Get the cache directory for models
fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> { fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
let home = std::env::var("HOME") let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?;
.or_else(|_| std::env::var("USERPROFILE"))?; let cache_dir = Path::new(&home)
let cache_dir = Path::new(&home).join(".cache").join("remove_background").join("models"); .join(".cache")
.join("remove_background")
.join("models");
fs::create_dir_all(&cache_dir)?; fs::create_dir_all(&cache_dir)?;
Ok(cache_dir) Ok(cache_dir)
} }
@ -35,21 +37,21 @@ fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
/// Download a file from URL to destination /// Download a file from URL to destination
fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> { fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
println!("Downloading model from {}...", url); println!("Downloading model from {}...", url);
let client = reqwest::blocking::Client::builder() let client = reqwest::blocking::Client::builder()
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36") .user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36")
.build()?; .build()?;
let mut response = client.get(url).send()?; let mut response = client.get(url).send()?;
if !response.status().is_success() { if !response.status().is_success() {
return Err(format!("Failed to download model: HTTP {}", response.status()).into()); return Err(format!("Failed to download model: HTTP {}", response.status()).into());
} }
let mut file = fs::File::create(dest)?; let mut file = fs::File::create(dest)?;
let total_size = response.content_length().unwrap_or(0); let total_size = response.content_length().unwrap_or(0);
let mut downloaded = 0u64; let mut downloaded = 0u64;
let mut buffer = vec![0; 8192]; let mut buffer = vec![0; 8192];
loop { loop {
let bytes_read = std::io::Read::read(&mut response, &mut buffer)?; let bytes_read = std::io::Read::read(&mut response, &mut buffer)?;
@ -58,20 +60,20 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
} }
file.write_all(&buffer[..bytes_read])?; file.write_all(&buffer[..bytes_read])?;
downloaded += bytes_read as u64; downloaded += bytes_read as u64;
if total_size > 0 { if total_size > 0 {
let progress = (downloaded as f64 / total_size as f64 * 100.0) as u32; let progress = (downloaded as f64 / total_size as f64 * 100.0) as u32;
print!("\rDownloading... {}%", progress); print!("\rDownloading... {}%", progress);
std::io::stdout().flush()?; std::io::stdout().flush()?;
} }
} }
if total_size > 0 { if total_size > 0 {
println!("\rDownload complete! "); println!("\rDownload complete! ");
} else { } else {
println!("Download complete! ({} bytes)", downloaded); println!("Download complete! ({} bytes)", downloaded);
} }
Ok(()) Ok(())
} }
@ -86,7 +88,11 @@ fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Er
} }
/// Get or download model, return path to model file /// Get or download model, return path to model file
pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline: bool) -> Result<PathBuf, Box<dyn Error>> { pub fn get_model_path(
model_info: &ModelInfo,
custom_path: Option<&str>,
offline: bool,
) -> Result<PathBuf, Box<dyn Error>> {
// If custom path provided, use it // If custom path provided, use it
if let Some(path) = custom_path { if let Some(path) = custom_path {
let model_path = PathBuf::from(path); let model_path = PathBuf::from(path);
@ -96,15 +102,15 @@ pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline
println!("Using custom model: {}", path); println!("Using custom model: {}", path);
return Ok(model_path); return Ok(model_path);
} }
// Check cache // Check cache
let cache_dir = get_cache_dir()?; let cache_dir = get_cache_dir()?;
let model_filename = format!("{}.onnx", model_info.name); let model_filename = format!("{}.onnx", model_info.name);
let model_path = cache_dir.join(&model_filename); let model_path = cache_dir.join(&model_filename);
if model_path.exists() { if model_path.exists() {
println!("Using cached model: {}", model_path.display()); println!("Using cached model: {}", model_path.display());
// Verify hash if provided // Verify hash if provided
if let Some(expected_hash) = model_info.sha256 { if let Some(expected_hash) = model_info.sha256 {
print!("Verifying model integrity... "); print!("Verifying model integrity... ");
@ -113,24 +119,27 @@ pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline
println!("OK"); println!("OK");
} else { } else {
println!("FAILED"); println!("FAILED");
return Err("Model hash verification failed. Try deleting the cached model.".into()); return Err(
"Model hash verification failed. Try deleting the cached model.".into(),
);
} }
} }
return Ok(model_path); return Ok(model_path);
} }
// Download if not in offline mode // Download if not in offline mode
if offline { if offline {
return Err(format!( return Err(format!(
"Model not found in cache and offline mode is enabled. Cache path: {}", "Model not found in cache and offline mode is enabled. Cache path: {}",
model_path.display() model_path.display()
).into()); )
.into());
} }
println!("Model not found in cache, downloading..."); println!("Model not found in cache, downloading...");
download_file(model_info.url, &model_path)?; download_file(model_info.url, &model_path)?;
// Verify after download // Verify after download
if let Some(expected_hash) = model_info.sha256 { if let Some(expected_hash) = model_info.sha256 {
print!("Verifying downloaded model... "); print!("Verifying downloaded model... ");
@ -143,20 +152,20 @@ pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline
return Err("Downloaded model hash verification failed".into()); return Err("Downloaded model hash verification failed".into());
} }
} }
Ok(model_path) Ok(model_path)
} }
/// Create an ONNX Runtime session from model path /// Create an ONNX Runtime session from model path
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> { pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
println!("Loading model into ONNX Runtime..."); println!("Loading model into ONNX Runtime...");
let session = Session::builder()? let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)? .with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)? .with_intra_threads(4)?
.commit_from_file(model_path)?; .commit_from_file(model_path)?;
println!("Model loaded successfully!"); println!("Model loaded successfully!");
Ok(session) Ok(session)
} }

View file

@ -1,9 +1,9 @@
use image::{DynamicImage, RgbaImage, Rgba, GenericImageView, imageops::FilterType}; use image::{DynamicImage, GenericImageView, Rgba, RgbaImage, imageops::FilterType};
use ndarray::Array4; use ndarray::Array4;
use std::error::Error; use std::error::Error;
/// Apply mask to original image to remove background /// Apply mask to original image to remove background
/// ///
/// Steps: /// Steps:
/// 1. Extract mask from output tensor (shape: [1, 1, H, W]) /// 1. Extract mask from output tensor (shape: [1, 1, H, W])
/// 2. Resize mask to match original image dimensions /// 2. Resize mask to match original image dimensions
@ -13,24 +13,21 @@ pub fn apply_mask(
mask_tensor: Array4<f32>, mask_tensor: Array4<f32>,
) -> Result<RgbaImage, Box<dyn Error>> { ) -> Result<RgbaImage, Box<dyn Error>> {
println!("Applying mask to remove background..."); println!("Applying mask to remove background...");
let (orig_width, orig_height) = original.dimensions(); let (orig_width, orig_height) = original.dimensions();
// Extract mask dimensions // Extract mask dimensions
let mask_shape = mask_tensor.shape(); let mask_shape = mask_tensor.shape();
if mask_shape[0] != 1 || mask_shape[1] != 1 { if mask_shape[0] != 1 || mask_shape[1] != 1 {
return Err(format!( return Err(format!("Expected mask shape [1, 1, H, W], got {:?}", mask_shape).into());
"Expected mask shape [1, 1, H, W], got {:?}",
mask_shape
).into());
} }
let mask_height = mask_shape[2] as u32; let mask_height = mask_shape[2] as u32;
let mask_width = mask_shape[3] as u32; let mask_width = mask_shape[3] as u32;
println!("Mask dimensions: {}x{}", mask_width, mask_height); println!("Mask dimensions: {}x{}", mask_width, mask_height);
println!("Original dimensions: {}x{}", orig_width, orig_height); println!("Original dimensions: {}x{}", orig_width, orig_height);
// Create a grayscale image from the mask // Create a grayscale image from the mask
let mut mask_image = image::GrayImage::new(mask_width, mask_height); let mut mask_image = image::GrayImage::new(mask_width, mask_height);
for y in 0..mask_height { for y in 0..mask_height {
@ -41,29 +38,24 @@ pub fn apply_mask(
mask_image.put_pixel(x, y, image::Luma([pixel_value])); mask_image.put_pixel(x, y, image::Luma([pixel_value]));
} }
} }
// Resize mask to match original image dimensions if needed // Resize mask to match original image dimensions if needed
let resized_mask = if mask_width != orig_width || mask_height != orig_height { let resized_mask = if mask_width != orig_width || mask_height != orig_height {
println!("Resizing mask to match original image..."); println!("Resizing mask to match original image...");
image::imageops::resize( image::imageops::resize(&mask_image, orig_width, orig_height, FilterType::Lanczos3)
&mask_image,
orig_width,
orig_height,
FilterType::Lanczos3,
)
} else { } else {
mask_image mask_image
}; };
// Convert original to RGBA and apply mask // Convert original to RGBA and apply mask
let rgba_original = original.to_rgba8(); let rgba_original = original.to_rgba8();
let mut result = RgbaImage::new(orig_width, orig_height); let mut result = RgbaImage::new(orig_width, orig_height);
for y in 0..orig_height { for y in 0..orig_height {
for x in 0..orig_width { for x in 0..orig_width {
let orig_pixel = rgba_original.get_pixel(x, y); let orig_pixel = rgba_original.get_pixel(x, y);
let mask_pixel = resized_mask.get_pixel(x, y); let mask_pixel = resized_mask.get_pixel(x, y);
// Apply mask as alpha channel // Apply mask as alpha channel
result.put_pixel( result.put_pixel(
x, x,
@ -77,9 +69,9 @@ pub fn apply_mask(
); );
} }
} }
println!("Background removal complete!"); println!("Background removal complete!");
Ok(result) Ok(result)
} }
@ -90,7 +82,7 @@ pub fn create_side_by_side(
) -> Result<RgbaImage, Box<dyn Error>> { ) -> Result<RgbaImage, Box<dyn Error>> {
let (width, height) = original.dimensions(); let (width, height) = original.dimensions();
let mut composite = RgbaImage::new(width * 2, height); let mut composite = RgbaImage::new(width * 2, height);
// Left side: original image // Left side: original image
let original_rgba = original.to_rgba8(); let original_rgba = original.to_rgba8();
for y in 0..height { for y in 0..height {
@ -98,31 +90,27 @@ pub fn create_side_by_side(
composite.put_pixel(x, y, *original_rgba.get_pixel(x, y)); composite.put_pixel(x, y, *original_rgba.get_pixel(x, y));
} }
} }
// Right side: result with checkered background for transparency // Right side: result with checkered background for transparency
for y in 0..height { for y in 0..height {
for x in 0..width { for x in 0..width {
let result_pixel = result.get_pixel(x, y); let result_pixel = result.get_pixel(x, y);
let alpha = result_pixel[3] as f32 / 255.0; let alpha = result_pixel[3] as f32 / 255.0;
// Create checkered background (8x8 squares) // Create checkered background (8x8 squares)
let checker_size = 8; let checker_size = 8;
let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0; let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0;
let bg_color = if is_light { 200 } else { 150 }; let bg_color = if is_light { 200 } else { 150 };
// Alpha blend with checkered background // Alpha blend with checkered background
let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
composite.put_pixel( composite.put_pixel(x + width, y, Rgba([final_r, final_g, final_b, 255]));
x + width,
y,
Rgba([final_r, final_g, final_b, 255]),
);
} }
} }
Ok(composite) Ok(composite)
} }
@ -130,7 +118,7 @@ pub fn create_side_by_side(
mod tests { mod tests {
use super::*; use super::*;
use ndarray::Array4; use ndarray::Array4;
#[test] #[test]
fn test_apply_mask_shape() { fn test_apply_mask_shape() {
let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100)); let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
@ -138,7 +126,7 @@ mod tests {
let result = apply_mask(&img, mask).unwrap(); let result = apply_mask(&img, mask).unwrap();
assert_eq!(result.dimensions(), (100, 100)); assert_eq!(result.dimensions(), (100, 100));
} }
#[test] #[test]
fn test_side_by_side_dimensions() { fn test_side_by_side_dimensions() {
let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100)); let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));

View file

@ -3,7 +3,7 @@ use ndarray::Array4;
use std::error::Error; use std::error::Error;
/// Preprocess an image for the BiRefNet model /// Preprocess an image for the BiRefNet model
/// ///
/// Steps: /// Steps:
/// 1. Resize to target dimensions (1024x1024) /// 1. Resize to target dimensions (1024x1024)
/// 2. Convert from u8 [0, 255] to f32 [0.0, 1.0] /// 2. Convert from u8 [0, 255] to f32 [0.0, 1.0]
@ -15,30 +15,30 @@ pub fn preprocess_image(
target_height: u32, target_height: u32,
) -> Result<Array4<f32>, Box<dyn Error>> { ) -> Result<Array4<f32>, Box<dyn Error>> {
println!("Preprocessing image..."); println!("Preprocessing image...");
// Step 1: Resize image // Step 1: Resize image
let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3); let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3);
let rgb_image = resized.to_rgb8(); let rgb_image = resized.to_rgb8();
let (width, height) = rgb_image.dimensions(); let (width, height) = rgb_image.dimensions();
// Step 2: Create ndarray with shape [1, 3, height, width] // Step 2: Create ndarray with shape [1, 3, height, width]
let mut array = Array4::<f32>::zeros((1, 3, height as usize, width as usize)); let mut array = Array4::<f32>::zeros((1, 3, height as usize, width as usize));
// Step 3: Fill the array, converting from HWC to CHW and normalizing // Step 3: Fill the array, converting from HWC to CHW and normalizing
for y in 0..height { for y in 0..height {
for x in 0..width { for x in 0..width {
let pixel = rgb_image.get_pixel(x, y); let pixel = rgb_image.get_pixel(x, y);
// Normalize from [0, 255] to [0.0, 1.0] // Normalize from [0, 255] to [0.0, 1.0]
array[[0, 0, y as usize, x as usize]] = pixel[0] as f32 / 255.0; // R array[[0, 0, y as usize, x as usize]] = pixel[0] as f32 / 255.0; // R
array[[0, 1, y as usize, x as usize]] = pixel[1] as f32 / 255.0; // G array[[0, 1, y as usize, x as usize]] = pixel[1] as f32 / 255.0; // G
array[[0, 2, y as usize, x as usize]] = pixel[2] as f32 / 255.0; // B array[[0, 2, y as usize, x as usize]] = pixel[2] as f32 / 255.0; // B
} }
} }
println!("Preprocessing complete. Tensor shape: {:?}", array.shape()); println!("Preprocessing complete. Tensor shape: {:?}", array.shape());
Ok(array) Ok(array)
} }
@ -46,19 +46,19 @@ pub fn preprocess_image(
mod tests { mod tests {
use super::*; use super::*;
use image::RgbImage; use image::RgbImage;
#[test] #[test]
fn test_preprocess_shape() { fn test_preprocess_shape() {
let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100)); let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
let result = preprocess_image(&img, 1024, 1024).unwrap(); let result = preprocess_image(&img, 1024, 1024).unwrap();
assert_eq!(result.shape(), &[1, 3, 1024, 1024]); assert_eq!(result.shape(), &[1, 3, 1024, 1024]);
} }
#[test] #[test]
fn test_preprocess_normalization() { fn test_preprocess_normalization() {
let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100)); let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
let result = preprocess_image(&img, 1024, 1024).unwrap(); let result = preprocess_image(&img, 1024, 1024).unwrap();
// Check that all values are in [0.0, 1.0] // Check that all values are in [0.0, 1.0]
for &val in result.iter() { for &val in result.iter() {
assert!(val >= 0.0 && val <= 1.0); assert!(val >= 0.0 && val <= 1.0);