2026-01-21 21:14:29 +00:00
|
|
|
use clap::Parser;
|
|
|
|
|
use image::GenericImageView;
|
2026-01-21 23:56:33 +00:00
|
|
|
use ndarray::IntoDimension;
|
2026-01-21 23:25:11 +00:00
|
|
|
use ort::value::Tensor;
|
2026-01-21 21:14:29 +00:00
|
|
|
use show_image::{AsImageView, create_window, event};
|
|
|
|
|
use std::error::Error;
|
|
|
|
|
|
2026-01-21 23:25:11 +00:00
|
|
|
mod model;
|
|
|
|
|
mod postprocessing;
|
|
|
|
|
mod preprocessing;
|
|
|
|
|
|
2026-01-22 22:20:38 +00:00
|
|
|
use model::{ModelInfo, create_session, get_model_path};
|
2026-01-21 23:25:11 +00:00
|
|
|
use postprocessing::{apply_mask, create_side_by_side};
|
|
|
|
|
use preprocessing::preprocess_image;
|
|
|
|
|
|
2026-01-21 21:14:29 +00:00
|
|
|
#[derive(Parser)]
|
|
|
|
|
#[command(name = "remove_background")]
|
2026-01-21 23:25:11 +00:00
|
|
|
#[command(about = "Remove background from images using ONNX models", long_about = None)]
|
2026-01-21 21:14:29 +00:00
|
|
|
struct Args {
|
2026-01-21 23:25:11 +00:00
|
|
|
#[arg(help = "URL of the image to download and process")]
|
2026-01-21 21:14:29 +00:00
|
|
|
url: String,
|
2026-01-21 23:25:11 +00:00
|
|
|
|
|
|
|
|
#[arg(long, help = "Path to custom ONNX model")]
|
|
|
|
|
model_path: Option<String>,
|
|
|
|
|
|
|
|
|
|
#[arg(long, help = "Skip model download, fail if not cached")]
|
|
|
|
|
offline: bool,
|
2026-01-21 21:14:29 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[show_image::main]
|
|
|
|
|
fn main() -> Result<(), Box<dyn Error>> {
|
|
|
|
|
// Parse command line arguments
|
|
|
|
|
let args = Args::parse();
|
|
|
|
|
|
2026-01-21 23:25:11 +00:00
|
|
|
println!("=== Background Removal Tool ===\n");
|
2026-01-21 21:14:29 +00:00
|
|
|
println!("Downloading image from: {}", args.url);
|
|
|
|
|
|
|
|
|
|
// Download the image with a user agent (using blocking client)
|
|
|
|
|
let client = reqwest::blocking::Client::builder()
|
|
|
|
|
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
|
|
|
|
.build()?;
|
|
|
|
|
let response = client.get(&args.url).send()?;
|
|
|
|
|
|
|
|
|
|
if !response.status().is_success() {
|
|
|
|
|
return Err(format!("Failed to download image: HTTP {}", response.status()).into());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let bytes = response.bytes()?;
|
|
|
|
|
println!("Downloaded {} bytes", bytes.len());
|
|
|
|
|
|
|
|
|
|
// Load the image
|
|
|
|
|
let img = image::load_from_memory(&bytes)?;
|
2026-01-22 20:48:08 +00:00
|
|
|
let (img_width, img_height) = img.dimensions();
|
|
|
|
|
println!("Loaded image: {}x{}\n", img_width, img_height);
|
2026-01-21 23:25:11 +00:00
|
|
|
|
|
|
|
|
// Get model info
|
2026-01-22 22:20:38 +00:00
|
|
|
let model_info = ModelInfo::BRIA;
|
2026-01-21 23:25:11 +00:00
|
|
|
println!("Using model: {}", model_info.name);
|
|
|
|
|
|
|
|
|
|
// Get or download model
|
|
|
|
|
let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?;
|
|
|
|
|
|
|
|
|
|
// Create ONNX Runtime session
|
|
|
|
|
let mut session = create_session(&model_path)?;
|
2026-01-22 22:20:38 +00:00
|
|
|
let input_name = session.inputs()[0].name().to_string();
|
2026-01-22 00:28:12 +00:00
|
|
|
|
2026-01-22 20:48:08 +00:00
|
|
|
// Preprocess
|
2026-01-22 22:20:38 +00:00
|
|
|
let input_tensor = preprocess_image(&img, 1024, 1024)?;
|
2026-01-21 23:25:11 +00:00
|
|
|
|
|
|
|
|
// Run inference
|
|
|
|
|
println!("Running inference...");
|
2026-01-22 22:20:38 +00:00
|
|
|
let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?;
|
2026-01-21 21:14:29 +00:00
|
|
|
|
2026-01-21 23:25:11 +00:00
|
|
|
// Extract mask output
|
|
|
|
|
let mask_output = &outputs[0];
|
|
|
|
|
let (mask_shape, mask_array) = mask_output.try_extract_tensor::<f32>()?; // Returns (shape, &[f32])
|
|
|
|
|
|
2026-01-21 23:56:33 +00:00
|
|
|
// Convert the slice to Array4
|
|
|
|
|
let mask_tensor = ndarray::ArrayView::from_shape(
|
|
|
|
|
(
|
|
|
|
|
mask_shape[0] as usize,
|
|
|
|
|
mask_shape[1] as usize,
|
|
|
|
|
mask_shape[2] as usize,
|
|
|
|
|
mask_shape[3] as usize,
|
|
|
|
|
)
|
|
|
|
|
.into_dimension(),
|
|
|
|
|
mask_array,
|
|
|
|
|
)?
|
|
|
|
|
.to_owned();
|
2026-01-21 23:25:11 +00:00
|
|
|
|
|
|
|
|
println!(
|
|
|
|
|
"Inference complete! Output shape: {:?}",
|
|
|
|
|
mask_tensor.shape()
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
// Apply mask to remove background
|
|
|
|
|
let result_rgba = apply_mask(&img, mask_tensor)?;
|
|
|
|
|
|
|
|
|
|
// Create side-by-side comparison
|
|
|
|
|
println!("Creating side-by-side comparison...");
|
|
|
|
|
let composite = create_side_by_side(&img, &result_rgba)?;
|
|
|
|
|
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
|
|
|
|
|
|
|
|
|
|
// Display the result
|
|
|
|
|
let (comp_width, comp_height) = composite_dynamic.dimensions();
|
|
|
|
|
let window = create_window(
|
|
|
|
|
"Background Removal - Original (Left) vs Result (Right) - Press ESC to close",
|
|
|
|
|
Default::default(),
|
|
|
|
|
)?;
|
2026-01-21 21:14:29 +00:00
|
|
|
window.set_image(
|
2026-01-21 23:25:11 +00:00
|
|
|
"comparison",
|
|
|
|
|
&composite_dynamic
|
|
|
|
|
.as_image_view()
|
|
|
|
|
.map_err(|e| e.to_string())?,
|
2026-01-21 21:14:29 +00:00
|
|
|
)?;
|
|
|
|
|
|
2026-01-21 23:25:11 +00:00
|
|
|
println!("\n=== Done! ===");
|
|
|
|
|
println!(
|
|
|
|
|
"Displaying side-by-side comparison ({}x{}):",
|
|
|
|
|
comp_width, comp_height
|
|
|
|
|
);
|
|
|
|
|
println!(" Left: Original image");
|
|
|
|
|
println!(" Right: Background removed (shown on checkered background)");
|
2026-01-21 23:56:33 +00:00
|
|
|
println!("\nPress ESC to close the window.");
|
2026-01-21 21:14:29 +00:00
|
|
|
// Event loop - wait for ESC key to close
|
|
|
|
|
for event in window.event_channel()? {
|
|
|
|
|
if let event::WindowEvent::KeyboardInput(event) = event {
|
|
|
|
|
if event.input.key_code == Some(event::VirtualKeyCode::Escape)
|
|
|
|
|
&& event.input.state.is_pressed()
|
|
|
|
|
{
|
|
|
|
|
println!("ESC pressed, closing...");
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|