This commit is contained in:
Matthew Deville 2026-04-22 22:30:41 +02:00
parent 8d2fad9106
commit 31d84c99d7
6 changed files with 235 additions and 57 deletions

112
Cargo.lock generated
View file

@ -45,6 +45,15 @@ dependencies = [
"zerocopy",
]
[[package]]
name = "aho-corasick"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
dependencies = [
"memchr",
]
[[package]]
name = "aligned"
version = "0.4.3"
@ -1776,6 +1785,15 @@ dependencies = [
"libc",
]
[[package]]
name = "matchers"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
dependencies = [
"regex-automata",
]
[[package]]
name = "matrixmultiply"
version = "0.3.10"
@ -2021,6 +2039,15 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8"
[[package]]
name = "nu-ansi-term"
version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
@ -2702,6 +2729,23 @@ dependencies = [
"bitflags 2.11.1",
]
[[package]]
name = "regex-automata"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
[[package]]
name = "remove_background"
version = "0.1.0"
@ -2714,6 +2758,8 @@ dependencies = [
"reqwest",
"sha2",
"show-image",
"tracing",
"tracing-subscriber",
]
[[package]]
@ -3029,6 +3075,15 @@ dependencies = [
"digest",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "shlex"
version = "1.3.0"
@ -3303,6 +3358,15 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "thread_local"
version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185"
dependencies = [
"cfg-if",
]
[[package]]
name = "tiff"
version = "0.11.3"
@ -3473,9 +3537,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
dependencies = [
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "tracing-core"
version = "0.1.36"
@ -3483,6 +3559,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex-automata",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]
[[package]]
@ -3598,6 +3704,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "valuable"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "vcpkg"
version = "0.2.15"

View file

@ -12,3 +12,5 @@
reqwest = { version = "0.13", features = ["blocking"] }
sha2 = "0.11"
show-image = { version = "0.14", features = ["image"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }

View file

@ -13,6 +13,8 @@ use {
io::{Cursor, Read, Write, stdin, stdout},
path::PathBuf,
},
tracing::{debug, info},
tracing_subscriber::{EnvFilter, fmt},
};
#[derive(Subcommand)]
@ -67,8 +69,20 @@ struct Args {
#[show_image::main]
fn main() -> Result<(), Box<dyn Error>> {
fmt()
.with_env_filter(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
.with_writer(std::io::stderr)
.init();
let args = Args::parse();
info!(
model = ?args.model,
custom_model = args.model_path.as_deref().unwrap_or("<none>"),
offline = args.offline,
"Starting remove_background"
);
let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?;
match args.command {
@ -78,8 +92,10 @@ fn main() -> Result<(), Box<dyn Error>> {
output_file,
} => {
let img = if let Some(input_file) = input_file {
debug!(path = %input_file.display(), "Reading input image from file");
ImageReader::open(input_file)?.decode()?
} else {
debug!("Reading input image from stdin");
let mut bytes = Vec::new();
if stdin().lock().read_to_end(&mut bytes)? == 0 {
return Err("No input data provided via stdin".into());
@ -90,15 +106,9 @@ fn main() -> Result<(), Box<dyn Error>> {
};
let (img_width, img_height) = img.dimensions();
println!("Loaded image: {}x{}\n", img_width, img_height);
info!(width = img_width, height = img_height, "Loaded image");
println!("Running inference...");
let mask = session.predict(&img)?;
println!(
"Inference complete! Mask dimensions: {}x{}",
mask.dimensions().0,
mask.dimensions().1
);
let result_rgba = apply_mask(&img, &mask)?;
if debug {
@ -106,8 +116,10 @@ fn main() -> Result<(), Box<dyn Error>> {
}
if let Some(output_file) = output_file {
result_rgba.save(output_file)?;
result_rgba.save(&output_file)?;
info!(path = %output_file.display(), "Wrote output image to file");
} else {
debug!("Writing output image to stdout (PNG)");
let mut buffer = Cursor::new(Vec::new());
result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?;
let mut stdout = stdout().lock();
@ -119,20 +131,49 @@ fn main() -> Result<(), Box<dyn Error>> {
output_directory,
} => {
fs::create_dir_all(&output_directory)?;
let files = fs::read_dir(input_directory)?;
for file in files {
let file = file?;
let path = file.path();
let img = ImageReader::open(&path)?.decode()?;
info!(
input = %input_directory.display(),
output = %output_directory.display(),
"Starting batch processing"
);
let mut processed: usize = 0;
let mut failed: usize = 0;
for entry in fs::read_dir(&input_directory)? {
let entry = entry?;
let path = entry.path();
let span = tracing::info_span!("batch_item", path = %path.display());
let _enter = span.enter();
let img = match ImageReader::open(&path) {
Ok(reader) => match reader.decode() {
Ok(img) => img,
Err(e) => {
tracing::warn!(error = %e, "Failed to decode image, skipping");
failed += 1;
continue;
}
},
Err(e) => {
tracing::warn!(error = %e, "Failed to open image, skipping");
failed += 1;
continue;
}
};
let mask = session.predict(&img)?;
let result_rgba = apply_mask(&img, &mask)?;
let output_path = output_directory.join(path.file_name().unwrap());
result_rgba.save(&output_path)?;
println!("Processed image saved to: {}", output_path.display());
processed += 1;
info!(output = %output_path.display(), "Processed image saved");
}
info!(processed, failed, "Batch processing complete");
}
}
info!("Shutting down");
Ok(())
}
@ -140,7 +181,7 @@ fn debug_mode(
img: &image::DynamicImage,
result_rgba: &image::RgbaImage,
) -> Result<(), Box<dyn Error>> {
println!("Creating side-by-side comparison...");
info!("Creating side-by-side comparison");
let composite = create_side_by_side(img, result_rgba)?;
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
@ -156,21 +197,18 @@ fn debug_mode(
.map_err(|e| e.to_string())?,
)?;
println!("\n=== Done! ===");
println!(
"Displaying side-by-side comparison ({}x{}):",
comp_width, comp_height
info!(
width = comp_width,
height = comp_height,
"Displaying side-by-side comparison (left: original, right: background removed on checkered bg); press ESC to close"
);
println!(" Left: Original image");
println!(" Right: Background removed (shown on checkered background)");
println!("\nPress ESC to close the window.");
for event in window.event_channel()? {
if let event::WindowEvent::KeyboardInput(event) = event
&& event.input.key_code == Some(event::VirtualKeyCode::Escape)
&& event.input.state.is_pressed()
{
println!("ESC pressed, closing...");
debug!("ESC pressed, closing window");
break;
}
}

View file

@ -8,9 +8,11 @@ use {
std::{
error::Error,
fs,
io::{Read, Write},
io::Read,
path::{Path, PathBuf},
time::Instant,
},
tracing::{debug, info, warn},
};
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
@ -34,7 +36,8 @@ fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
}
fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
println!("Downloading model from {}...", url);
info!(%url, dest = %dest.display(), "Downloading model");
let start = Instant::now();
let client = reqwest::blocking::Client::builder()
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36")
@ -46,9 +49,10 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
return Err(format!("Failed to download model: HTTP {}", response.status()).into());
}
let mut file = fs::File::create(dest)?;
let mut file = std::io::BufWriter::new(fs::File::create(dest)?);
let total_size = response.content_length().unwrap_or(0);
let mut downloaded = 0u64;
let mut last_reported_pct: u32 = 0;
let mut buffer = vec![0; 8192];
loop {
@ -56,21 +60,24 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
if bytes_read == 0 {
break;
}
file.write_all(&buffer[..bytes_read])?;
std::io::Write::write_all(&mut file, &buffer[..bytes_read])?;
downloaded += bytes_read as u64;
if total_size > 0 {
let progress = (downloaded as f64 / total_size as f64 * 100.0) as u32;
print!("\rDownloading... {}%", progress);
std::io::stdout().flush()?;
// Only emit every 10% to avoid log spam while still giving feedback.
if progress >= last_reported_pct + 10 {
last_reported_pct = progress - (progress % 10);
debug!(percent = last_reported_pct, downloaded, total = total_size, "Download progress");
}
}
}
if total_size > 0 {
println!("\rDownload complete! ");
} else {
println!("Download complete! ({} bytes)", downloaded);
}
info!(
bytes = downloaded,
elapsed_ms = start.elapsed().as_millis() as u64,
"Download complete"
);
Ok(())
}
@ -110,7 +117,7 @@ pub fn get_model_path(
if !model_path.exists() {
return Err(format!("Custom model path does not exist: {}", path).into());
}
println!("Using custom model: {}", path);
info!(%path, "Using custom model");
return Ok(model_path);
}
@ -119,15 +126,14 @@ pub fn get_model_path(
let model_path = cache_dir.join(&model_filename);
if model_path.exists() {
println!("Using cached model: {}", model_path.display());
info!(path = %model_path.display(), "Using cached model");
if let Some(expected_hash) = sha256 {
print!("Verifying model integrity... ");
std::io::stdout().flush()?;
debug!("Verifying cached model integrity");
if verify_hash(&model_path, expected_hash)? {
println!("OK");
debug!("Cached model hash OK");
} else {
println!("FAILED");
warn!(path = %model_path.display(), "Cached model hash verification failed");
return Err(
"Model hash verification failed. Try deleting the cached model.".into(),
);
@ -145,14 +151,13 @@ pub fn get_model_path(
.into());
}
println!("Model not found in cache, downloading...");
info!("Model not found in cache, downloading");
download_file(url, &model_path)?;
if let Some(expected_hash) = sha256 {
print!("Verifying downloaded model... ");
std::io::stdout().flush()?;
debug!("Verifying downloaded model");
if verify_hash(&model_path, expected_hash)? {
println!("OK");
debug!("Downloaded model hash OK");
} else {
fs::remove_file(&model_path)?;
return Err("Downloaded model hash verification failed".into());
@ -164,7 +169,8 @@ pub fn get_model_path(
/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU).
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
println!("Loading model into ONNX Runtime with CUDA backend...");
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend");
let start = Instant::now();
let session = Session::builder()?
.with_execution_providers([CUDAExecutionProvider::default().build()])?
@ -172,7 +178,10 @@ pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
.with_intra_threads(4)?
.commit_from_file(model_path)?;
println!("Model loaded successfully with CUDA!");
info!(
elapsed_ms = start.elapsed().as_millis() as u64,
"Model loaded successfully"
);
Ok(session)
}

View file

@ -1,6 +1,7 @@
use {
image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType},
std::error::Error,
tracing::debug,
};
/// Compose `original` with `mask` as the alpha channel and return an RGBA image.
@ -8,17 +9,18 @@ use {
/// The mask is expected to already be grayscale. If its dimensions differ from
/// the original, it is resized with LANCZOS3.
pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage, Box<dyn Error>> {
println!("Applying mask to remove background...");
let (orig_width, orig_height) = original.dimensions();
let (mask_width, mask_height) = mask.dimensions();
println!("Mask dimensions: {}x{}", mask_width, mask_height);
println!("Original dimensions: {}x{}", orig_width, orig_height);
let resized_mask: std::borrow::Cow<'_, GrayImage> =
if mask_width != orig_width || mask_height != orig_height {
println!("Resizing mask to match original image...");
debug!(
mask_width,
mask_height,
orig_width,
orig_height,
"Resizing mask to match original image"
);
std::borrow::Cow::Owned(image::imageops::resize(
mask,
orig_width,
@ -45,7 +47,7 @@ pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage
}
}
println!("Background removal complete!");
debug!("Mask applied to image");
Ok(result)
}

View file

@ -2,7 +2,8 @@ use {
image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType},
ndarray::{Array4, IntoDimension},
ort::{session::Session as OrtSession, value::Tensor},
std::error::Error,
std::{error::Error, time::Instant},
tracing::{debug, info},
};
use crate::model::Model;
@ -105,13 +106,22 @@ pub trait Session {
/// 5. scale to `u8`, resize to original image dimensions
fn predict(&mut self, img: &DynamicImage) -> Result<GrayImage, Box<dyn Error>> {
let (orig_w, orig_h) = img.dimensions();
let preprocess_start = Instant::now();
let input = self.normalize(img)?;
debug!(elapsed_ms = preprocess_start.elapsed().as_millis() as u64, "Preprocessing complete");
let apply_sigmoid = self.apply_sigmoid();
let input_name = self.input_name().to_string();
debug!(%input_name, "Running ONNX session");
let infer_start = Instant::now();
let outputs = self
.inner()
.run(ort::inputs![input_name => Tensor::from_array(input)?])?;
info!(
elapsed_ms = infer_start.elapsed().as_millis() as u64,
"ONNX inference complete"
);
let output = &outputs[0];
let (shape, data) = output.try_extract_tensor::<f32>()?;
@ -177,6 +187,11 @@ pub trait Session {
}
let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3);
debug!(
mask_width = orig_w,
mask_height = orig_h,
"Mask resized to original image dimensions"
);
Ok(mask)
}
}
@ -191,7 +206,7 @@ pub fn init_session(
if !path.exists() {
return Err(format!("Custom model path does not exist: {}", custom_path).into());
}
println!("Using custom model: {}", custom_path);
info!(path = %custom_path, "Using custom model");
match model {
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
@ -199,11 +214,11 @@ pub fn init_session(
} else {
match model {
Model::Bria => {
println!("Using model: bria-rmbg");
info!(model = "bria-rmbg", "Using model");
Box::new(BriaSession::new(offline)?)
}
Model::BiRefNetLite => {
println!("Using model: birefnet-general-lite");
info!(model = "birefnet-general-lite", "Using model");
Box::new(BiRefNetLiteSession::new(offline)?)
}
}