single and batch

This commit is contained in:
Matthew Deville 2026-04-22 22:09:13 +02:00
parent 395990e47d
commit 8d2fad9106
6 changed files with 786 additions and 483 deletions

1007
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -4,10 +4,11 @@
version = "0.1.0"
[dependencies]
clap = { version = "4.5", features = ["derive"] }
clap = { version = "4", features = ["derive"] }
hex = "0.4"
image = "0.25"
ndarray = "0.17"
ort = "=2.0.0-rc.11"
ort = "=2.0.0-rc.12"
reqwest = { version = "0.13", features = ["blocking"] }
sha2 = "0.10"
sha2 = "0.11"
show-image = { version = "0.14", features = ["image"] }

View file

@ -1,21 +1,50 @@
use clap::Parser;
use image::GenericImageView;
use show_image::{AsImageView, create_window, event};
use std::error::Error;
use remove_background::{
use {
clap::{Parser, Subcommand},
image::{GenericImageView, ImageReader},
remove_background::{
model::Model,
postprocessing::{apply_mask, create_side_by_side},
sessions::{BiRefNetLiteSession, BriaSession, Session},
sessions::init_session,
},
show_image::{AsImageView, create_window, event},
std::{
error::Error,
fs,
io::{Cursor, Read, Write, stdin, stdout},
path::PathBuf,
},
};
#[derive(Subcommand)]
enum Command {
Single {
#[arg(long, help = "Run in debug mode, show images side by side")]
debug: bool,
#[arg(
value_name = "INPUT_FILE",
help = "File to process, if not provided, read from stdin"
)]
input_file: Option<PathBuf>,
#[arg(
value_name = "OUTPUT_FILE",
help = "File to write result to, if not provided, write to stdout"
)]
output_file: Option<PathBuf>,
},
Batch {
#[arg(value_name = "INPUT_DIRECTORY", help = "Directory to read images from")]
input_directory: PathBuf,
#[arg(value_name = "OUTPUT_DIRECTORY", help = "Directory to write images to")]
output_directory: PathBuf,
},
}
#[derive(Parser)]
#[command(name = "remove_background")]
#[command(about = "Remove background from images using ONNX models", long_about = None)]
struct Args {
#[arg(help = "URL of the image to download and process")]
url: String,
#[arg(
short,
long,
@ -31,54 +60,38 @@ struct Args {
#[arg(long, help = "Skip model download, fail if not cached")]
offline: bool,
#[command(subcommand)]
command: Command,
}
#[show_image::main]
fn main() -> Result<(), Box<dyn Error>> {
let args = Args::parse();
println!("=== Background Removal Tool ===\n");
println!("Downloading image from: {}", args.url);
let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?;
let client = reqwest::blocking::Client::builder()
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
.build()?;
let response = client.get(&args.url).send()?;
if !response.status().is_success() {
return Err(format!("Failed to download image: HTTP {}", response.status()).into());
match args.command {
Command::Single {
debug,
input_file,
output_file,
} => {
let img = if let Some(input_file) = input_file {
ImageReader::open(input_file)?.decode()?
} else {
let mut bytes = Vec::new();
if stdin().lock().read_to_end(&mut bytes)? == 0 {
return Err("No input data provided via stdin".into());
}
ImageReader::new(Cursor::new(bytes))
.with_guessed_format()?
.decode()?
};
let bytes = response.bytes()?;
println!("Downloaded {} bytes", bytes.len());
let img = image::load_from_memory(&bytes)?;
let (img_width, img_height) = img.dimensions();
println!("Loaded image: {}x{}\n", img_width, img_height);
let mut session: Box<dyn Session> = if let Some(custom_path) = args.model_path.as_deref() {
let path = std::path::PathBuf::from(custom_path);
if !path.exists() {
return Err(format!("Custom model path does not exist: {}", custom_path).into());
}
println!("Using custom model: {}", custom_path);
match args.model {
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
}
} else {
match args.model {
Model::Bria => {
println!("Using model: bria-rmbg");
Box::new(BriaSession::new(args.offline)?)
}
Model::BiRefNetLite => {
println!("Using model: birefnet-general-lite");
Box::new(BiRefNetLiteSession::new(args.offline)?)
}
}
};
println!("Running inference...");
let mask = session.predict(&img)?;
println!(
@ -86,11 +99,49 @@ fn main() -> Result<(), Box<dyn Error>> {
mask.dimensions().0,
mask.dimensions().1
);
let result_rgba = apply_mask(&img, &mask)?;
if debug {
debug_mode(&img, &result_rgba)?;
}
if let Some(output_file) = output_file {
result_rgba.save(output_file)?;
} else {
let mut buffer = Cursor::new(Vec::new());
result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?;
let mut stdout = stdout().lock();
stdout.write_all(&buffer.into_inner())?;
}
}
Command::Batch {
input_directory,
output_directory,
} => {
fs::create_dir_all(&output_directory)?;
let files = fs::read_dir(input_directory)?;
for file in files {
let file = file?;
let path = file.path();
let img = ImageReader::open(&path)?.decode()?;
let mask = session.predict(&img)?;
let result_rgba = apply_mask(&img, &mask)?;
let output_path = output_directory.join(path.file_name().unwrap());
result_rgba.save(&output_path)?;
println!("Processed image saved to: {}", output_path.display());
}
}
}
Ok(())
}
fn debug_mode(
img: &image::DynamicImage,
result_rgba: &image::RgbaImage,
) -> Result<(), Box<dyn Error>> {
println!("Creating side-by-side comparison...");
let composite = create_side_by_side(&img, &result_rgba)?;
let composite = create_side_by_side(img, result_rgba)?;
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
let (comp_width, comp_height) = composite_dynamic.dimensions();

View file

@ -1,11 +1,17 @@
use clap::ValueEnum;
use ort::execution_providers::CUDAExecutionProvider;
use ort::session::{Session, builder::GraphOptimizationLevel};
use sha2::{Digest, Sha256};
use std::error::Error;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use {
clap::ValueEnum,
ort::{
execution_providers::CUDAExecutionProvider,
session::{Session, builder::GraphOptimizationLevel},
},
sha2::{Digest, Sha256},
std::{
error::Error,
fs,
io::{Read, Write},
path::{Path, PathBuf},
},
};
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
/// preprocessing params) lives on the `Session` trait impls in `sessions/`.
@ -70,12 +76,19 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
}
fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Error>> {
const BUF_SIZE: usize = 8192;
let mut file = fs::File::open(file_path)?;
let mut buf = [0; BUF_SIZE];
let mut hasher = Sha256::new();
std::io::copy(&mut file, &mut hasher)?;
let hash = hasher.finalize();
let hash_str = format!("{:x}", hash);
Ok(hash_str == expected_hash)
loop {
let bytes_read = file.read(&mut buf)?;
if bytes_read == 0 {
break;
}
hasher.update(&buf[..bytes_read]);
}
Ok(hex::encode(hasher.finalize()) == expected_hash)
}
/// Resolve a model file to a local path, downloading + verifying if needed.

View file

@ -1,5 +1,7 @@
use image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType};
use std::error::Error;
use {
image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType},
std::error::Error,
};
/// Compose `original` with `mask` as the alpha channel and return an RGBA image.
///

View file

@ -1,7 +1,11 @@
use image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType};
use ndarray::{Array4, IntoDimension};
use ort::{session::Session as OrtSession, value::Tensor};
use std::error::Error;
use {
image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType},
ndarray::{Array4, IntoDimension},
ort::{session::Session as OrtSession, value::Tensor},
std::error::Error,
};
use crate::model::Model;
mod birefnet_lite;
mod bria;
@ -176,3 +180,32 @@ pub trait Session {
Ok(mask)
}
}
pub fn init_session(
custom_model_path: Option<&str>,
model: Model,
offline: bool,
) -> Result<Box<dyn Session>, Box<dyn Error>> {
Ok(if let Some(custom_path) = custom_model_path {
let path = std::path::PathBuf::from(custom_path);
if !path.exists() {
return Err(format!("Custom model path does not exist: {}", custom_path).into());
}
println!("Using custom model: {}", custom_path);
match model {
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
}
} else {
match model {
Model::Bria => {
println!("Using model: bria-rmbg");
Box::new(BriaSession::new(offline)?)
}
Model::BiRefNetLite => {
println!("Using model: birefnet-general-lite");
Box::new(BiRefNetLiteSession::new(offline)?)
}
}
})
}