remove_background/src/main.rs

141 lines
4.4 KiB
Rust
Raw Normal View History

2026-01-21 21:14:29 +00:00
use clap::Parser;
use image::GenericImageView;
2026-01-21 23:56:33 +00:00
use ndarray::IntoDimension;
2026-01-21 23:25:11 +00:00
use ort::value::Tensor;
2026-01-21 21:14:29 +00:00
use show_image::{AsImageView, create_window, event};
use std::error::Error;
2026-01-21 23:25:11 +00:00
mod model;
mod postprocessing;
mod preprocessing;
2026-01-22 20:48:08 +00:00
use model::{ModelInfo, create_session, get_model_path, infer_input_info};
2026-01-21 23:25:11 +00:00
use postprocessing::{apply_mask, create_side_by_side};
use preprocessing::preprocess_image;
2026-01-21 21:14:29 +00:00
#[derive(Parser)]
#[command(name = "remove_background")]
2026-01-21 23:25:11 +00:00
#[command(about = "Remove background from images using ONNX models", long_about = None)]
2026-01-21 21:14:29 +00:00
struct Args {
2026-01-21 23:25:11 +00:00
#[arg(help = "URL of the image to download and process")]
2026-01-21 21:14:29 +00:00
url: String,
2026-01-21 23:25:11 +00:00
#[arg(long, help = "Path to custom ONNX model")]
model_path: Option<String>,
#[arg(long, help = "Skip model download, fail if not cached")]
offline: bool,
2026-01-21 21:14:29 +00:00
}
#[show_image::main]
fn main() -> Result<(), Box<dyn Error>> {
// Parse command line arguments
let args = Args::parse();
2026-01-21 23:25:11 +00:00
println!("=== Background Removal Tool ===\n");
2026-01-21 21:14:29 +00:00
println!("Downloading image from: {}", args.url);
// Download the image with a user agent (using blocking client)
let client = reqwest::blocking::Client::builder()
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
.build()?;
let response = client.get(&args.url).send()?;
if !response.status().is_success() {
return Err(format!("Failed to download image: HTTP {}", response.status()).into());
}
let bytes = response.bytes()?;
println!("Downloaded {} bytes", bytes.len());
// Load the image
let img = image::load_from_memory(&bytes)?;
2026-01-22 20:48:08 +00:00
let (img_width, img_height) = img.dimensions();
println!("Loaded image: {}x{}\n", img_width, img_height);
2026-01-21 23:25:11 +00:00
// Get model info
let model_info = ModelInfo::BIREFNET_LITE;
println!("Using model: {}", model_info.name);
// Get or download model
let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?;
// Create ONNX Runtime session
let mut session = create_session(&model_path)?;
2026-01-22 20:48:08 +00:00
let input_info = infer_input_info(&session, img_width, img_height)?;
2026-01-22 00:28:12 +00:00
2026-01-22 20:48:08 +00:00
// Preprocess
let input_tensor = preprocess_image(&img, input_info.shape.width, input_info.shape.height)?;
2026-01-21 23:25:11 +00:00
// Run inference
println!("Running inference...");
2026-01-22 20:48:08 +00:00
let outputs =
session.run(ort::inputs![input_info.name => Tensor::from_array(input_tensor)?])?;
2026-01-21 21:14:29 +00:00
2026-01-21 23:25:11 +00:00
// Extract mask output
let mask_output = &outputs[0];
let (mask_shape, mask_array) = mask_output.try_extract_tensor::<f32>()?; // Returns (shape, &[f32])
2026-01-21 23:56:33 +00:00
// Convert the slice to Array4
let mask_tensor = ndarray::ArrayView::from_shape(
(
mask_shape[0] as usize,
mask_shape[1] as usize,
mask_shape[2] as usize,
mask_shape[3] as usize,
)
.into_dimension(),
mask_array,
)?
.to_owned();
2026-01-21 23:25:11 +00:00
println!(
"Inference complete! Output shape: {:?}",
mask_tensor.shape()
);
// Apply mask to remove background
let result_rgba = apply_mask(&img, mask_tensor)?;
// Create side-by-side comparison
println!("Creating side-by-side comparison...");
let composite = create_side_by_side(&img, &result_rgba)?;
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
// Display the result
let (comp_width, comp_height) = composite_dynamic.dimensions();
let window = create_window(
"Background Removal - Original (Left) vs Result (Right) - Press ESC to close",
Default::default(),
)?;
2026-01-21 21:14:29 +00:00
window.set_image(
2026-01-21 23:25:11 +00:00
"comparison",
&composite_dynamic
.as_image_view()
.map_err(|e| e.to_string())?,
2026-01-21 21:14:29 +00:00
)?;
2026-01-21 23:25:11 +00:00
println!("\n=== Done! ===");
println!(
"Displaying side-by-side comparison ({}x{}):",
comp_width, comp_height
);
println!(" Left: Original image");
println!(" Right: Background removed (shown on checkered background)");
2026-01-21 23:56:33 +00:00
println!("\nPress ESC to close the window.");
2026-01-21 21:14:29 +00:00
// Event loop - wait for ESC key to close
for event in window.event_channel()? {
if let event::WindowEvent::KeyboardInput(event) = event {
if event.input.key_code == Some(event::VirtualKeyCode::Escape)
&& event.input.state.is_pressed()
{
println!("ESC pressed, closing...");
break;
}
}
}
Ok(())
}