use clap::Parser; use image::GenericImageView; use ndarray::IntoDimension; use ort::value::Tensor; use show_image::{AsImageView, create_window, event}; use std::error::Error; mod model; mod postprocessing; mod preprocessing; use model::{ModelInfo, create_session, get_model_path}; use postprocessing::{apply_mask, create_side_by_side}; use preprocessing::preprocess_image; #[derive(Parser)] #[command(name = "remove_background")] #[command(about = "Remove background from images using ONNX models", long_about = None)] struct Args { #[arg(help = "URL of the image to download and process")] url: String, #[arg(long, help = "Path to custom ONNX model")] model_path: Option, #[arg(long, help = "Skip model download, fail if not cached")] offline: bool, } #[show_image::main] fn main() -> Result<(), Box> { // Parse command line arguments let args = Args::parse(); println!("=== Background Removal Tool ===\n"); println!("Downloading image from: {}", args.url); // Download the image with a user agent (using blocking client) let client = reqwest::blocking::Client::builder() .user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36") .build()?; let response = client.get(&args.url).send()?; if !response.status().is_success() { return Err(format!("Failed to download image: HTTP {}", response.status()).into()); } let bytes = response.bytes()?; println!("Downloaded {} bytes", bytes.len()); // Load the image let img = image::load_from_memory(&bytes)?; let (img_width, img_height) = img.dimensions(); println!("Loaded image: {}x{}\n", img_width, img_height); // Get model info let model_info = ModelInfo::BRIA; println!("Using model: {}", model_info.name); // Get or download model let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?; // Create ONNX Runtime session let mut session = create_session(&model_path)?; let input_name = session.inputs()[0].name().to_string(); // Preprocess let input_tensor = preprocess_image(&img, 1024, 1024)?; // Run inference println!("Running inference..."); let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?; // Extract mask output let mask_output = &outputs[0]; let (mask_shape, mask_array) = mask_output.try_extract_tensor::()?; // Returns (shape, &[f32]) // Convert the slice to Array4 let mask_tensor = ndarray::ArrayView::from_shape( ( mask_shape[0] as usize, mask_shape[1] as usize, mask_shape[2] as usize, mask_shape[3] as usize, ) .into_dimension(), mask_array, )? .to_owned(); println!( "Inference complete! Output shape: {:?}", mask_tensor.shape() ); // Apply mask to remove background let result_rgba = apply_mask(&img, mask_tensor)?; // Create side-by-side comparison println!("Creating side-by-side comparison..."); let composite = create_side_by_side(&img, &result_rgba)?; let composite_dynamic = image::DynamicImage::ImageRgba8(composite); // Display the result let (comp_width, comp_height) = composite_dynamic.dimensions(); let window = create_window( "Background Removal - Original (Left) vs Result (Right) - Press ESC to close", Default::default(), )?; window.set_image( "comparison", &composite_dynamic .as_image_view() .map_err(|e| e.to_string())?, )?; println!("\n=== Done! ==="); println!( "Displaying side-by-side comparison ({}x{}):", comp_width, comp_height ); println!(" Left: Original image"); println!(" Right: Background removed (shown on checkered background)"); println!("\nPress ESC to close the window."); // Event loop - wait for ESC key to close for event in window.event_channel()? { if let event::WindowEvent::KeyboardInput(event) = event { if event.input.key_code == Some(event::VirtualKeyCode::Escape) && event.input.state.is_pressed() { println!("ESC pressed, closing..."); break; } } } Ok(()) }