diff --git a/Cargo.lock b/Cargo.lock index e747387..445c477 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2750,6 +2750,7 @@ checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" name = "remove_background" version = "0.1.0" dependencies = [ + "anyhow", "clap", "hex", "image", @@ -2758,6 +2759,7 @@ dependencies = [ "reqwest", "sha2", "show-image", + "thiserror 2.0.18", "tracing", "tracing-subscriber", ] @@ -2870,9 +2872,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.38" +version = "0.23.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69f9466fb2c14ea04357e91413efb882e2a6d4a406e625449bc0a5d360d53a21" +checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e" dependencies = [ "aws-lc-rs", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index 23e626b..55669c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" [dependencies] + anyhow = "1" clap = { version = "4", features = ["derive"] } hex = "0.4" image = "0.25" @@ -12,5 +13,6 @@ reqwest = { version = "0.13", features = ["blocking"] } sha2 = "0.11" show-image = { version = "0.14", features = ["image"] } + thiserror = "2" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..e5e7cc4 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,48 @@ +use {std::path::PathBuf, thiserror::Error}; + +pub type AppResult = Result; + +#[derive(Debug, Error)] +pub enum AppError { + #[error(transparent)] + Io(#[from] std::io::Error), + + #[error(transparent)] + EnvVar(#[from] std::env::VarError), + + #[error(transparent)] + Image(#[from] image::ImageError), + + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + + #[error(transparent)] + Ort(#[from] ort::Error), + + #[error("failed to configure ONNX session builder: {message}")] + OrtSessionBuilder { message: String }, + + #[error(transparent)] + NdarrayShape(#[from] ndarray::ShapeError), + + #[error("no input data provided via stdin")] + NoStdinInput, + + #[error("download failed: HTTP {status}")] + DownloadHttpStatus { status: reqwest::StatusCode }, + + #[error("custom model path does not exist: {path}")] + CustomModelPathMissing { path: PathBuf }, + + #[error("model hash verification failed: {path}")] + ModelHashVerificationFailed { path: PathBuf }, + + #[error("model not found in cache and offline mode is enabled: {cache_path}")] + OfflineModelMissing { cache_path: PathBuf }, + + #[error("expected 4D output tensor [N, C, H, W], got shape {shape:?}")] + UnexpectedTensorShape { shape: Vec }, + + #[error("expected batch size 1, got {batch_size}")] + UnexpectedBatchSize { batch_size: usize }, +} diff --git a/src/lib.rs b/src/lib.rs index 9971b51..cabc53a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod error; pub mod model; pub mod postprocessing; pub mod sessions; diff --git a/src/main.rs b/src/main.rs index 5787363..a55a3b8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,15 @@ use { + anyhow::{Context, Result}, clap::{Parser, Subcommand}, image::{GenericImageView, ImageReader}, remove_background::{ + error::AppError, model::Model, postprocessing::{apply_mask, create_side_by_side}, sessions::init_session, }, show_image::{AsImageView, create_window, event}, std::{ - error::Error, fs, io::{Cursor, Read, Write, stdin, stdout}, path::PathBuf, @@ -68,7 +69,7 @@ struct Args { } #[show_image::main] -fn main() -> Result<(), Box> { +fn main() -> Result<()> { fmt() .with_env_filter( EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), @@ -85,7 +86,8 @@ fn main() -> Result<(), Box> { "Starting remove_background" ); - let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?; + let mut session = init_session(args.model_path.as_deref(), args.model, args.offline) + .context("failed to initialize inference session")?; match args.command { Command::Single { @@ -95,36 +97,52 @@ fn main() -> Result<(), Box> { } => { let img = if let Some(input_file) = input_file { debug!(path = %input_file.display(), "Reading input image from file"); - ImageReader::open(input_file)?.decode()? + ImageReader::open(&input_file) + .with_context(|| { + format!("failed to open input image {}", input_file.display()) + })? + .decode() + .with_context(|| { + format!("failed to decode input image {}", input_file.display()) + })? } else { debug!("Reading input image from stdin"); let mut bytes = Vec::new(); if stdin().lock().read_to_end(&mut bytes)? == 0 { - return Err("No input data provided via stdin".into()); + return Err(AppError::NoStdinInput.into()); } ImageReader::new(Cursor::new(bytes)) .with_guessed_format()? - .decode()? + .decode() + .context("failed to decode input image from stdin")? }; let (img_width, img_height) = img.dimensions(); info!(width = img_width, height = img_height, "Loaded image"); - let mask = session.predict(&img)?; - let result_rgba = apply_mask(&img, &mask)?; + let mask = session + .predict(&img) + .context("failed to predict segmentation mask")?; + let result_rgba = apply_mask(&img, &mask).context("failed to apply mask to image")?; if debug { debug_mode(&img, &result_rgba)?; } else { if let Some(output_file) = output_file { - result_rgba.save(&output_file)?; + result_rgba.save(&output_file).with_context(|| { + format!("failed to save output image {}", output_file.display()) + })?; info!(path = %output_file.display(), "Wrote output image to file"); } else { debug!("Writing output image to stdout (PNG)"); let mut buffer = Cursor::new(Vec::new()); - result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?; + result_rgba + .write_to(&mut buffer, image::ImageFormat::Png) + .context("failed to encode PNG for stdout output")?; let mut stdout = stdout().lock(); - stdout.write_all(&buffer.into_inner())?; + stdout + .write_all(&buffer.into_inner()) + .context("failed to write PNG bytes to stdout")?; } } } @@ -132,7 +150,12 @@ fn main() -> Result<(), Box> { input_directory, output_directory, } => { - fs::create_dir_all(&output_directory)?; + fs::create_dir_all(&output_directory).with_context(|| { + format!( + "failed to create output directory {}", + output_directory.display() + ) + })?; info!( input = %input_directory.display(), output = %output_directory.display(), @@ -141,7 +164,12 @@ fn main() -> Result<(), Box> { let mut processed: usize = 0; let mut failed: usize = 0; - for entry in fs::read_dir(&input_directory)? { + for entry in fs::read_dir(&input_directory).with_context(|| { + format!( + "failed to read input directory {}", + input_directory.display() + ) + })? { let entry = entry?; let path = entry.path(); let span = tracing::info_span!("batch_item", path = %path.display()); @@ -163,11 +191,16 @@ fn main() -> Result<(), Box> { } }; - let mask = session.predict(&img)?; - let result_rgba = apply_mask(&img, &mask)?; + let mask = session + .predict(&img) + .with_context(|| format!("failed to predict mask for {}", path.display()))?; + let result_rgba = apply_mask(&img, &mask) + .with_context(|| format!("failed to apply mask for {}", path.display()))?; let mut output_path = output_directory.join(path.file_name().unwrap()); output_path.set_extension("png"); - result_rgba.save(&output_path)?; + result_rgba.save(&output_path).with_context(|| { + format!("failed to save processed image {}", output_path.display()) + })?; processed += 1; info!(output = %output_path.display(), "Processed image saved"); } @@ -180,12 +213,10 @@ fn main() -> Result<(), Box> { Ok(()) } -fn debug_mode( - img: &image::DynamicImage, - result_rgba: &image::RgbaImage, -) -> Result<(), Box> { +fn debug_mode(img: &image::DynamicImage, result_rgba: &image::RgbaImage) -> Result<()> { info!("Creating side-by-side comparison"); - let composite = create_side_by_side(img, result_rgba)?; + let composite = create_side_by_side(img, result_rgba) + .context("failed to create side-by-side debug image")?; let composite_dynamic = image::DynamicImage::ImageRgba8(composite); let (comp_width, comp_height) = composite_dynamic.dimensions(); @@ -197,7 +228,7 @@ fn debug_mode( "comparison", composite_dynamic .as_image_view() - .map_err(|e| e.to_string())?, + .map_err(|e| anyhow::anyhow!(e.to_string()))?, )?; info!( diff --git a/src/model.rs b/src/model.rs index 18911d2..4778fe5 100644 --- a/src/model.rs +++ b/src/model.rs @@ -6,7 +6,6 @@ use { }, sha2::{Digest, Sha256}, std::{ - error::Error, fs, io::Read, path::{Path, PathBuf}, @@ -15,6 +14,8 @@ use { tracing::{debug, info, warn}, }; +use crate::error::{AppError, AppResult}; + /// CLI-facing model selector. Concrete session metadata (URL, checksum, /// preprocessing params) lives on the `Session` trait impls in `sessions/`. #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] @@ -25,7 +26,7 @@ pub enum Model { } /// Get the cache directory for models -fn get_cache_dir() -> Result> { +fn get_cache_dir() -> AppResult { let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?; let cache_dir = Path::new(&home) .join(".cache") @@ -35,7 +36,7 @@ fn get_cache_dir() -> Result> { Ok(cache_dir) } -fn download_file(url: &str, dest: &Path) -> Result<(), Box> { +fn download_file(url: &str, dest: &Path) -> AppResult<()> { info!(%url, dest = %dest.display(), "Downloading model"); let start = Instant::now(); @@ -46,7 +47,9 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box> { let mut response = client.get(url).send()?; if !response.status().is_success() { - return Err(format!("Failed to download model: HTTP {}", response.status()).into()); + return Err(AppError::DownloadHttpStatus { + status: response.status(), + }); } let mut file = std::io::BufWriter::new(fs::File::create(dest)?); @@ -87,7 +90,7 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box> { Ok(()) } -fn verify_hash(file_path: &Path, expected_hash: &str) -> Result> { +fn verify_hash(file_path: &Path, expected_hash: &str) -> AppResult { const BUF_SIZE: usize = 8192; let mut file = fs::File::open(file_path)?; @@ -116,11 +119,11 @@ pub fn get_model_path( sha256: Option<&str>, custom_path: Option<&str>, offline: bool, -) -> Result> { +) -> AppResult { if let Some(path) = custom_path { let model_path = PathBuf::from(path); if !model_path.exists() { - return Err(format!("Custom model path does not exist: {}", path).into()); + return Err(AppError::CustomModelPathMissing { path: model_path }); } info!(%path, "Using custom model"); return Ok(model_path); @@ -139,9 +142,9 @@ pub fn get_model_path( debug!("Cached model hash OK"); } else { warn!(path = %model_path.display(), "Cached model hash verification failed"); - return Err( - "Model hash verification failed. Try deleting the cached model.".into(), - ); + return Err(AppError::ModelHashVerificationFailed { + path: model_path.clone(), + }); } } @@ -149,11 +152,9 @@ pub fn get_model_path( } if offline { - return Err(format!( - "Model not found in cache and offline mode is enabled. Cache path: {}", - model_path.display() - ) - .into()); + return Err(AppError::OfflineModelMissing { + cache_path: model_path, + }); } info!("Model not found in cache, downloading"); @@ -165,7 +166,7 @@ pub fn get_model_path( debug!("Downloaded model hash OK"); } else { fs::remove_file(&model_path)?; - return Err("Downloaded model hash verification failed".into()); + return Err(AppError::ModelHashVerificationFailed { path: model_path }); } } @@ -173,14 +174,23 @@ pub fn get_model_path( } /// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU). -pub fn create_session(model_path: &Path) -> Result> { +pub fn create_session(model_path: &Path) -> AppResult { info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend"); let start = Instant::now(); let session = Session::builder()? - .with_execution_providers([CUDAExecutionProvider::default().build()])? - .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_intra_threads(4)? + .with_execution_providers([CUDAExecutionProvider::default().build()]) + .map_err(|err| AppError::OrtSessionBuilder { + message: err.to_string(), + })? + .with_optimization_level(GraphOptimizationLevel::Level3) + .map_err(|err| AppError::OrtSessionBuilder { + message: err.to_string(), + })? + .with_intra_threads(4) + .map_err(|err| AppError::OrtSessionBuilder { + message: err.to_string(), + })? .commit_from_file(model_path)?; info!( @@ -190,3 +200,54 @@ pub fn create_session(model_path: &Path) -> Result> { Ok(session) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn custom_model_path_missing_returns_typed_error() { + let missing_path = PathBuf::from("/definitely/not/a/real/model.onnx"); + let result = get_model_path( + "unused", + "https://example.invalid/model.onnx", + None, + Some(missing_path.to_str().expect("valid utf-8 path")), + false, + ); + + match result { + Err(AppError::CustomModelPathMissing { path }) => assert_eq!(path, missing_path), + other => panic!("unexpected result: {other:?}"), + } + } + + #[test] + fn offline_missing_model_returns_typed_error() { + let unique_name = format!( + "test-model-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system time before epoch") + .as_nanos() + ); + let result = get_model_path( + &unique_name, + "https://example.invalid/model.onnx", + None, + None, + true, + ); + + match result { + Err(AppError::OfflineModelMissing { cache_path }) => { + let expected = format!("{unique_name}.onnx"); + assert_eq!( + cache_path.file_name().and_then(|f| f.to_str()), + Some(expected.as_str()) + ); + } + other => panic!("unexpected result: {other:?}"), + } + } +} diff --git a/src/postprocessing.rs b/src/postprocessing.rs index 8a8afa2..0363ef3 100644 --- a/src/postprocessing.rs +++ b/src/postprocessing.rs @@ -1,14 +1,15 @@ use { image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType}, - std::error::Error, tracing::debug, }; +use crate::error::AppResult; + /// Compose `original` with `mask` as the alpha channel and return an RGBA image. /// /// The mask is expected to already be grayscale. If its dimensions differ from /// the original, it is resized with LANCZOS3. -pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result> { +pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> AppResult { let (orig_width, orig_height) = original.dimensions(); let (mask_width, mask_height) = mask.dimensions(); @@ -50,10 +51,7 @@ pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result Result> { +pub fn create_side_by_side(original: &DynamicImage, result: &RgbaImage) -> AppResult { let (width, height) = original.dimensions(); let mut composite = RgbaImage::new(width * 2, height); diff --git a/src/sessions/birefnet_lite.rs b/src/sessions/birefnet_lite.rs index e8f7949..adf7cb1 100644 --- a/src/sessions/birefnet_lite.rs +++ b/src/sessions/birefnet_lite.rs @@ -1,8 +1,9 @@ -use ort::session::Session as OrtSession; -use std::error::Error; -use std::path::Path; +use {ort::session::Session as OrtSession, std::path::Path}; -use crate::model::{create_session, get_model_path}; +use crate::{ + error::AppResult, + model::{create_session, get_model_path}, +}; use super::Session; @@ -12,12 +13,12 @@ pub struct BiRefNetLiteSession { } impl BiRefNetLiteSession { - pub fn new(offline: bool) -> Result> { + pub fn new(offline: bool) -> AppResult { let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?; Self::from_model_path(&path) } - pub fn from_model_path(path: &Path) -> Result> { + pub fn from_model_path(path: &Path) -> AppResult { let inner_session = create_session(path)?; let input_name = inner_session.inputs()[0].name().to_string(); Ok(Self { diff --git a/src/sessions/bria.rs b/src/sessions/bria.rs index 4cae5f5..dabf0b8 100644 --- a/src/sessions/bria.rs +++ b/src/sessions/bria.rs @@ -1,8 +1,9 @@ -use ort::session::Session as OrtSession; -use std::error::Error; -use std::path::Path; +use {ort::session::Session as OrtSession, std::path::Path}; -use crate::model::{create_session, get_model_path}; +use crate::{ + error::AppResult, + model::{create_session, get_model_path}, +}; use super::Session; @@ -12,12 +13,12 @@ pub struct BriaSession { } impl BriaSession { - pub fn new(offline: bool) -> Result> { + pub fn new(offline: bool) -> AppResult { let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?; Self::from_model_path(&path) } - pub fn from_model_path(path: &Path) -> Result> { + pub fn from_model_path(path: &Path) -> AppResult { let inner_session = create_session(path)?; let input_name = inner_session.inputs()[0].name().to_string(); Ok(Self { diff --git a/src/sessions/mod.rs b/src/sessions/mod.rs index 1f20727..e9d1a2a 100644 --- a/src/sessions/mod.rs +++ b/src/sessions/mod.rs @@ -2,11 +2,14 @@ use { image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType}, ndarray::{Array4, IntoDimension}, ort::{session::Session as OrtSession, value::Tensor}, - std::{error::Error, time::Instant}, + std::time::Instant, tracing::{debug, info}, }; -use crate::model::Model; +use crate::{ + error::{AppError, AppResult}, + model::Model, +}; mod birefnet_lite; mod bria; @@ -62,7 +65,7 @@ pub trait Session { /// Port of rembg's `BaseSession.normalize`: resize with LANCZOS, /// scale into `[0, 1]` by dividing by the max pixel value, then apply /// channel-wise mean/std. - fn normalize(&self, img: &DynamicImage) -> Result, Box> { + fn normalize(&self, img: &DynamicImage) -> AppResult> { let (w, h) = self.input_size(); let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8(); let (width, height) = resized.dimensions(); @@ -104,7 +107,7 @@ pub trait Session { /// 3. optional sigmoid (birefnet) /// 4. min/max normalize into `[0, 1]` /// 5. scale to `u8`, resize to original image dimensions - fn predict(&mut self, img: &DynamicImage) -> Result> { + fn predict(&mut self, img: &DynamicImage) -> AppResult { let (orig_w, orig_h) = img.dimensions(); let preprocess_start = Instant::now(); @@ -130,11 +133,9 @@ pub trait Session { let (shape, data) = output.try_extract_tensor::()?; if shape.len() != 4 { - return Err(format!( - "Expected 4D output tensor [N, C, H, W], got shape {:?}", - shape - ) - .into()); + return Err(AppError::UnexpectedTensorShape { + shape: shape.iter().copied().collect(), + }); } let (n, _c, h, w) = ( shape[0] as usize, @@ -143,7 +144,7 @@ pub trait Session { shape[3] as usize, ); if n != 1 { - return Err(format!("Expected batch size 1, got {}", n).into()); + return Err(AppError::UnexpectedBatchSize { batch_size: n }); } let view = ndarray::ArrayView::from_shape( @@ -203,11 +204,11 @@ pub fn init_session( custom_model_path: Option<&str>, model: Model, offline: bool, -) -> Result, Box> { +) -> AppResult> { Ok(if let Some(custom_path) = custom_model_path { let path = std::path::PathBuf::from(custom_path); if !path.exists() { - return Err(format!("Custom model path does not exist: {}", custom_path).into()); + return Err(AppError::CustomModelPathMissing { path }); } info!(path = %custom_path, "Using custom model"); match model {