Compare commits

...

5 commits

Author SHA1 Message Date
Matthew Deville
25a4a204b9 make cuda a feature 2026-04-23 18:13:40 +02:00
Matthew Deville
5660af24f7 depdency for cuda 2026-04-23 17:36:20 +02:00
Matthew Deville
c217578536 error handling 2026-04-23 14:51:16 +02:00
Matthew Deville
ebdbe13f26 fmt 2026-04-23 14:28:19 +02:00
Matthew Deville
3be68fcccf png as default output 2026-04-23 14:26:13 +02:00
12 changed files with 288 additions and 106 deletions

6
Cargo.lock generated
View file

@ -2750,6 +2750,7 @@ checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
name = "remove_background" name = "remove_background"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow",
"clap", "clap",
"hex", "hex",
"image", "image",
@ -2758,6 +2759,7 @@ dependencies = [
"reqwest", "reqwest",
"sha2", "sha2",
"show-image", "show-image",
"thiserror 2.0.18",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
] ]
@ -2870,9 +2872,9 @@ dependencies = [
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.23.38" version = "0.23.39"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f9466fb2c14ea04357e91413efb882e2a6d4a406e625449bc0a5d360d53a21" checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e"
dependencies = [ dependencies = [
"aws-lc-rs", "aws-lc-rs",
"once_cell", "once_cell",

View file

@ -3,14 +3,20 @@
name = "remove_background" name = "remove_background"
version = "0.1.0" version = "0.1.0"
[features]
default = []
cuda = ["ort/cuda"]
[dependencies] [dependencies]
anyhow = "1"
clap = { version = "4", features = ["derive"] } clap = { version = "4", features = ["derive"] }
hex = "0.4" hex = "0.4"
image = "0.25" image = "0.25"
ndarray = "0.17" ndarray = "0.17"
ort = "=2.0.0-rc.12" ort = { version = "=2.0.0-rc.12" }
reqwest = { version = "0.13", features = ["blocking"] } reqwest = { version = "0.13", features = ["blocking"] }
sha2 = "0.11" sha2 = "0.11"
show-image = { version = "0.14", features = ["image"] } show-image = { version = "0.14", features = ["image"] }
thiserror = "2"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }

View file

@ -62,11 +62,11 @@
"nixpkgs": "nixpkgs_2" "nixpkgs": "nixpkgs_2"
}, },
"locked": { "locked": {
"lastModified": 1776827647, "lastModified": 1776914043,
"narHash": "sha256-sYixYhp5V8jCajO8TRorE4fzs7IkL4MZdfLTKgkPQBk=", "narHash": "sha256-qug5r56yW1qOsjSI99l3Jm15JNT9CvS2otkXNRNtrPI=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "40e6ccc06e1245a4837cbbd6bdda64e21cc67379", "rev": "2d35c4358d7de3a0e606a6e8b27925d981c01cc3",
"type": "github" "type": "github"
}, },
"original": { "original": {

View file

@ -18,7 +18,10 @@
system: system:
let let
overlays = [ (import rust-overlay) ]; overlays = [ (import rust-overlay) ];
pkgs = import nixpkgs { inherit system overlays; }; pkgs = import nixpkgs {
inherit system overlays;
config.allowUnfree = true;
};
stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv; stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv;
rustToolchain = pkgs.rust-bin.stable.latest.default.override { rustToolchain = pkgs.rust-bin.stable.latest.default.override {
extensions = [ "rust-src" ]; extensions = [ "rust-src" ];
@ -32,10 +35,10 @@
]; ];
xorgBuildInputs = with pkgs; [ xorgBuildInputs = with pkgs; [
xorg.libx11 libx11
xorg.libxcursor libxcursor
xorg.libxi libxi
xorg.libxrandr libxrandr
]; ];
waylandBuildInputs = with pkgs; [ waylandBuildInputs = with pkgs; [
@ -49,7 +52,20 @@
openssl openssl
]; ];
buildInputs = xorgBuildInputs ++ waylandBuildInputs ++ graphicsBuildInputs; # CUDA runtime libraries required by ONNX Runtime CUDA EP.
cudaBuildInputs = with pkgs.cudaPackages; [
cuda_cudart
libcublas
libcurand
libcufft
cudnn
];
buildInputs =
xorgBuildInputs
++ waylandBuildInputs
++ graphicsBuildInputs
++ cudaBuildInputs;
mkShell = pkgs.mkShell.override { mkShell = pkgs.mkShell.override {
stdenv = stdenv; stdenv = stdenv;

48
src/error.rs Normal file
View file

@ -0,0 +1,48 @@
use {std::path::PathBuf, thiserror::Error};
pub type AppResult<T> = Result<T, AppError>;
#[derive(Debug, Error)]
pub enum AppError {
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
EnvVar(#[from] std::env::VarError),
#[error(transparent)]
Image(#[from] image::ImageError),
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
#[error(transparent)]
Ort(#[from] ort::Error),
#[error("failed to configure ONNX session builder: {message}")]
OrtSessionBuilder { message: String },
#[error(transparent)]
NdarrayShape(#[from] ndarray::ShapeError),
#[error("no input data provided via stdin")]
NoStdinInput,
#[error("download failed: HTTP {status}")]
DownloadHttpStatus { status: reqwest::StatusCode },
#[error("custom model path does not exist: {path}")]
CustomModelPathMissing { path: PathBuf },
#[error("model hash verification failed: {path}")]
ModelHashVerificationFailed { path: PathBuf },
#[error("model not found in cache and offline mode is enabled: {cache_path}")]
OfflineModelMissing { cache_path: PathBuf },
#[error("expected 4D output tensor [N, C, H, W], got shape {shape:?}")]
UnexpectedTensorShape { shape: Vec<i64> },
#[error("expected batch size 1, got {batch_size}")]
UnexpectedBatchSize { batch_size: usize },
}

View file

@ -1,3 +1,4 @@
pub mod error;
pub mod model; pub mod model;
pub mod postprocessing; pub mod postprocessing;
pub mod sessions; pub mod sessions;

View file

@ -1,14 +1,15 @@
use { use {
anyhow::{Context, Result},
clap::{Parser, Subcommand}, clap::{Parser, Subcommand},
image::{GenericImageView, ImageReader}, image::{GenericImageView, ImageReader},
remove_background::{ remove_background::{
error::AppError,
model::Model, model::Model,
postprocessing::{apply_mask, create_side_by_side}, postprocessing::{apply_mask, create_side_by_side},
sessions::init_session, sessions::init_session,
}, },
show_image::{AsImageView, create_window, event}, show_image::{AsImageView, create_window, event},
std::{ std::{
error::Error,
fs, fs,
io::{Cursor, Read, Write, stdin, stdout}, io::{Cursor, Read, Write, stdin, stdout},
path::PathBuf, path::PathBuf,
@ -68,9 +69,11 @@ struct Args {
} }
#[show_image::main] #[show_image::main]
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<()> {
fmt() fmt()
.with_env_filter(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"))) .with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
)
.with_writer(std::io::stderr) .with_writer(std::io::stderr)
.init(); .init();
@ -83,7 +86,8 @@ fn main() -> Result<(), Box<dyn Error>> {
"Starting remove_background" "Starting remove_background"
); );
let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?; let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)
.context("failed to initialize inference session")?;
match args.command { match args.command {
Command::Single { Command::Single {
@ -93,44 +97,65 @@ fn main() -> Result<(), Box<dyn Error>> {
} => { } => {
let img = if let Some(input_file) = input_file { let img = if let Some(input_file) = input_file {
debug!(path = %input_file.display(), "Reading input image from file"); debug!(path = %input_file.display(), "Reading input image from file");
ImageReader::open(input_file)?.decode()? ImageReader::open(&input_file)
.with_context(|| {
format!("failed to open input image {}", input_file.display())
})?
.decode()
.with_context(|| {
format!("failed to decode input image {}", input_file.display())
})?
} else { } else {
debug!("Reading input image from stdin"); debug!("Reading input image from stdin");
let mut bytes = Vec::new(); let mut bytes = Vec::new();
if stdin().lock().read_to_end(&mut bytes)? == 0 { if stdin().lock().read_to_end(&mut bytes)? == 0 {
return Err("No input data provided via stdin".into()); return Err(AppError::NoStdinInput.into());
} }
ImageReader::new(Cursor::new(bytes)) ImageReader::new(Cursor::new(bytes))
.with_guessed_format()? .with_guessed_format()?
.decode()? .decode()
.context("failed to decode input image from stdin")?
}; };
let (img_width, img_height) = img.dimensions(); let (img_width, img_height) = img.dimensions();
info!(width = img_width, height = img_height, "Loaded image"); info!(width = img_width, height = img_height, "Loaded image");
let mask = session.predict(&img)?; let mask = session
let result_rgba = apply_mask(&img, &mask)?; .predict(&img)
.context("failed to predict segmentation mask")?;
let result_rgba = apply_mask(&img, &mask).context("failed to apply mask to image")?;
if debug { if debug {
debug_mode(&img, &result_rgba)?; debug_mode(&img, &result_rgba)?;
} } else {
if let Some(output_file) = output_file { if let Some(output_file) = output_file {
result_rgba.save(&output_file)?; result_rgba.save(&output_file).with_context(|| {
format!("failed to save output image {}", output_file.display())
})?;
info!(path = %output_file.display(), "Wrote output image to file"); info!(path = %output_file.display(), "Wrote output image to file");
} else { } else {
debug!("Writing output image to stdout (PNG)"); debug!("Writing output image to stdout (PNG)");
let mut buffer = Cursor::new(Vec::new()); let mut buffer = Cursor::new(Vec::new());
result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?; result_rgba
.write_to(&mut buffer, image::ImageFormat::Png)
.context("failed to encode PNG for stdout output")?;
let mut stdout = stdout().lock(); let mut stdout = stdout().lock();
stdout.write_all(&buffer.into_inner())?; stdout
.write_all(&buffer.into_inner())
.context("failed to write PNG bytes to stdout")?;
}
} }
} }
Command::Batch { Command::Batch {
input_directory, input_directory,
output_directory, output_directory,
} => { } => {
fs::create_dir_all(&output_directory)?; fs::create_dir_all(&output_directory).with_context(|| {
format!(
"failed to create output directory {}",
output_directory.display()
)
})?;
info!( info!(
input = %input_directory.display(), input = %input_directory.display(),
output = %output_directory.display(), output = %output_directory.display(),
@ -139,7 +164,12 @@ fn main() -> Result<(), Box<dyn Error>> {
let mut processed: usize = 0; let mut processed: usize = 0;
let mut failed: usize = 0; let mut failed: usize = 0;
for entry in fs::read_dir(&input_directory)? { for entry in fs::read_dir(&input_directory).with_context(|| {
format!(
"failed to read input directory {}",
input_directory.display()
)
})? {
let entry = entry?; let entry = entry?;
let path = entry.path(); let path = entry.path();
let span = tracing::info_span!("batch_item", path = %path.display()); let span = tracing::info_span!("batch_item", path = %path.display());
@ -161,10 +191,16 @@ fn main() -> Result<(), Box<dyn Error>> {
} }
}; };
let mask = session.predict(&img)?; let mask = session
let result_rgba = apply_mask(&img, &mask)?; .predict(&img)
let output_path = output_directory.join(path.file_name().unwrap()); .with_context(|| format!("failed to predict mask for {}", path.display()))?;
result_rgba.save(&output_path)?; let result_rgba = apply_mask(&img, &mask)
.with_context(|| format!("failed to apply mask for {}", path.display()))?;
let mut output_path = output_directory.join(path.file_name().unwrap());
output_path.set_extension("png");
result_rgba.save(&output_path).with_context(|| {
format!("failed to save processed image {}", output_path.display())
})?;
processed += 1; processed += 1;
info!(output = %output_path.display(), "Processed image saved"); info!(output = %output_path.display(), "Processed image saved");
} }
@ -177,12 +213,10 @@ fn main() -> Result<(), Box<dyn Error>> {
Ok(()) Ok(())
} }
fn debug_mode( fn debug_mode(img: &image::DynamicImage, result_rgba: &image::RgbaImage) -> Result<()> {
img: &image::DynamicImage,
result_rgba: &image::RgbaImage,
) -> Result<(), Box<dyn Error>> {
info!("Creating side-by-side comparison"); info!("Creating side-by-side comparison");
let composite = create_side_by_side(img, result_rgba)?; let composite = create_side_by_side(img, result_rgba)
.context("failed to create side-by-side debug image")?;
let composite_dynamic = image::DynamicImage::ImageRgba8(composite); let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
let (comp_width, comp_height) = composite_dynamic.dimensions(); let (comp_width, comp_height) = composite_dynamic.dimensions();
@ -194,7 +228,7 @@ fn debug_mode(
"comparison", "comparison",
composite_dynamic composite_dynamic
.as_image_view() .as_image_view()
.map_err(|e| e.to_string())?, .map_err(|e| anyhow::anyhow!(e.to_string()))?,
)?; )?;
info!( info!(

View file

@ -1,12 +1,8 @@
use { use {
clap::ValueEnum, clap::ValueEnum,
ort::{ ort::session::Session,
execution_providers::CUDAExecutionProvider,
session::{Session, builder::GraphOptimizationLevel},
},
sha2::{Digest, Sha256}, sha2::{Digest, Sha256},
std::{ std::{
error::Error,
fs, fs,
io::Read, io::Read,
path::{Path, PathBuf}, path::{Path, PathBuf},
@ -15,6 +11,11 @@ use {
tracing::{debug, info, warn}, tracing::{debug, info, warn},
}; };
use crate::error::{AppError, AppResult};
#[cfg(feature = "cuda")]
use ort::{builder::GraphOptimizationLevel, ep::CUDAExecutionProvider};
/// CLI-facing model selector. Concrete session metadata (URL, checksum, /// CLI-facing model selector. Concrete session metadata (URL, checksum,
/// preprocessing params) lives on the `Session` trait impls in `sessions/`. /// preprocessing params) lives on the `Session` trait impls in `sessions/`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
@ -25,7 +26,7 @@ pub enum Model {
} }
/// Get the cache directory for models /// Get the cache directory for models
fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> { fn get_cache_dir() -> AppResult<PathBuf> {
let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?; let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?;
let cache_dir = Path::new(&home) let cache_dir = Path::new(&home)
.join(".cache") .join(".cache")
@ -35,7 +36,7 @@ fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
Ok(cache_dir) Ok(cache_dir)
} }
fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> { fn download_file(url: &str, dest: &Path) -> AppResult<()> {
info!(%url, dest = %dest.display(), "Downloading model"); info!(%url, dest = %dest.display(), "Downloading model");
let start = Instant::now(); let start = Instant::now();
@ -46,7 +47,9 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
let mut response = client.get(url).send()?; let mut response = client.get(url).send()?;
if !response.status().is_success() { if !response.status().is_success() {
return Err(format!("Failed to download model: HTTP {}", response.status()).into()); return Err(AppError::DownloadHttpStatus {
status: response.status(),
});
} }
let mut file = std::io::BufWriter::new(fs::File::create(dest)?); let mut file = std::io::BufWriter::new(fs::File::create(dest)?);
@ -68,7 +71,12 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
// Only emit every 10% to avoid log spam while still giving feedback. // Only emit every 10% to avoid log spam while still giving feedback.
if progress >= last_reported_pct + 10 { if progress >= last_reported_pct + 10 {
last_reported_pct = progress - (progress % 10); last_reported_pct = progress - (progress % 10);
debug!(percent = last_reported_pct, downloaded, total = total_size, "Download progress"); debug!(
percent = last_reported_pct,
downloaded,
total = total_size,
"Download progress"
);
} }
} }
} }
@ -82,7 +90,7 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
Ok(()) Ok(())
} }
fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Error>> { fn verify_hash(file_path: &Path, expected_hash: &str) -> AppResult<bool> {
const BUF_SIZE: usize = 8192; const BUF_SIZE: usize = 8192;
let mut file = fs::File::open(file_path)?; let mut file = fs::File::open(file_path)?;
@ -111,11 +119,11 @@ pub fn get_model_path(
sha256: Option<&str>, sha256: Option<&str>,
custom_path: Option<&str>, custom_path: Option<&str>,
offline: bool, offline: bool,
) -> Result<PathBuf, Box<dyn Error>> { ) -> AppResult<PathBuf> {
if let Some(path) = custom_path { if let Some(path) = custom_path {
let model_path = PathBuf::from(path); let model_path = PathBuf::from(path);
if !model_path.exists() { if !model_path.exists() {
return Err(format!("Custom model path does not exist: {}", path).into()); return Err(AppError::CustomModelPathMissing { path: model_path });
} }
info!(%path, "Using custom model"); info!(%path, "Using custom model");
return Ok(model_path); return Ok(model_path);
@ -134,9 +142,9 @@ pub fn get_model_path(
debug!("Cached model hash OK"); debug!("Cached model hash OK");
} else { } else {
warn!(path = %model_path.display(), "Cached model hash verification failed"); warn!(path = %model_path.display(), "Cached model hash verification failed");
return Err( return Err(AppError::ModelHashVerificationFailed {
"Model hash verification failed. Try deleting the cached model.".into(), path: model_path.clone(),
); });
} }
} }
@ -144,11 +152,9 @@ pub fn get_model_path(
} }
if offline { if offline {
return Err(format!( return Err(AppError::OfflineModelMissing {
"Model not found in cache and offline mode is enabled. Cache path: {}", cache_path: model_path,
model_path.display() });
)
.into());
} }
info!("Model not found in cache, downloading"); info!("Model not found in cache, downloading");
@ -160,23 +166,40 @@ pub fn get_model_path(
debug!("Downloaded model hash OK"); debug!("Downloaded model hash OK");
} else { } else {
fs::remove_file(&model_path)?; fs::remove_file(&model_path)?;
return Err("Downloaded model hash verification failed".into()); return Err(AppError::ModelHashVerificationFailed { path: model_path });
} }
} }
Ok(model_path) Ok(model_path)
} }
/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU). /// Create an ONNX Runtime session from a model path.
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> { ///
/// Uses the CUDA execution provider when built with `--features cuda`; otherwise runs on CPU.
pub fn create_session(model_path: &Path) -> AppResult<Session> {
#[cfg(feature = "cuda")]
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend"); info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend");
#[cfg(not(feature = "cuda"))]
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CPU backend");
let start = Instant::now(); let start = Instant::now();
let session = Session::builder()? let mut builder = Session::builder()?;
.with_execution_providers([CUDAExecutionProvider::default().build()])? #[cfg(feature = "cuda")]
.with_optimization_level(GraphOptimizationLevel::Level3)? let builder = builder
.with_intra_threads(4)? .with_execution_providers([CUDAExecutionProvider::default().build()])
.commit_from_file(model_path)?; .map_err(|err| AppError::OrtSessionBuilder {
message: err.to_string(),
})?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|err| AppError::OrtSessionBuilder {
message: err.to_string(),
})?
.with_intra_threads(4)
.map_err(|err| AppError::OrtSessionBuilder {
message: err.to_string(),
})?;
let session = builder.commit_from_file(model_path)?;
info!( info!(
elapsed_ms = start.elapsed().as_millis() as u64, elapsed_ms = start.elapsed().as_millis() as u64,
@ -185,3 +208,54 @@ pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
Ok(session) Ok(session)
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn custom_model_path_missing_returns_typed_error() {
let missing_path = PathBuf::from("/definitely/not/a/real/model.onnx");
let result = get_model_path(
"unused",
"https://example.invalid/model.onnx",
None,
Some(missing_path.to_str().expect("valid utf-8 path")),
false,
);
match result {
Err(AppError::CustomModelPathMissing { path }) => assert_eq!(path, missing_path),
other => panic!("unexpected result: {other:?}"),
}
}
#[test]
fn offline_missing_model_returns_typed_error() {
let unique_name = format!(
"test-model-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("system time before epoch")
.as_nanos()
);
let result = get_model_path(
&unique_name,
"https://example.invalid/model.onnx",
None,
None,
true,
);
match result {
Err(AppError::OfflineModelMissing { cache_path }) => {
let expected = format!("{unique_name}.onnx");
assert_eq!(
cache_path.file_name().and_then(|f| f.to_str()),
Some(expected.as_str())
);
}
other => panic!("unexpected result: {other:?}"),
}
}
}

View file

@ -1,14 +1,15 @@
use { use {
image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType}, image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType},
std::error::Error,
tracing::debug, tracing::debug,
}; };
use crate::error::AppResult;
/// Compose `original` with `mask` as the alpha channel and return an RGBA image. /// Compose `original` with `mask` as the alpha channel and return an RGBA image.
/// ///
/// The mask is expected to already be grayscale. If its dimensions differ from /// The mask is expected to already be grayscale. If its dimensions differ from
/// the original, it is resized with LANCZOS3. /// the original, it is resized with LANCZOS3.
pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage, Box<dyn Error>> { pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> AppResult<RgbaImage> {
let (orig_width, orig_height) = original.dimensions(); let (orig_width, orig_height) = original.dimensions();
let (mask_width, mask_height) = mask.dimensions(); let (mask_width, mask_height) = mask.dimensions();
@ -16,10 +17,7 @@ pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage
if mask_width != orig_width || mask_height != orig_height { if mask_width != orig_width || mask_height != orig_height {
debug!( debug!(
mask_width, mask_width,
mask_height, mask_height, orig_width, orig_height, "Resizing mask to match original image"
orig_width,
orig_height,
"Resizing mask to match original image"
); );
std::borrow::Cow::Owned(image::imageops::resize( std::borrow::Cow::Owned(image::imageops::resize(
mask, mask,
@ -53,10 +51,7 @@ pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage
} }
/// Create a side-by-side comparison image /// Create a side-by-side comparison image
pub fn create_side_by_side( pub fn create_side_by_side(original: &DynamicImage, result: &RgbaImage) -> AppResult<RgbaImage> {
original: &DynamicImage,
result: &RgbaImage,
) -> Result<RgbaImage, Box<dyn Error>> {
let (width, height) = original.dimensions(); let (width, height) = original.dimensions();
let mut composite = RgbaImage::new(width * 2, height); let mut composite = RgbaImage::new(width * 2, height);

View file

@ -1,8 +1,9 @@
use ort::session::Session as OrtSession; use {ort::session::Session as OrtSession, std::path::Path};
use std::error::Error;
use std::path::Path;
use crate::model::{create_session, get_model_path}; use crate::{
error::AppResult,
model::{create_session, get_model_path},
};
use super::Session; use super::Session;
@ -12,12 +13,12 @@ pub struct BiRefNetLiteSession {
} }
impl BiRefNetLiteSession { impl BiRefNetLiteSession {
pub fn new(offline: bool) -> Result<Self, Box<dyn Error>> { pub fn new(offline: bool) -> AppResult<Self> {
let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?; let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?;
Self::from_model_path(&path) Self::from_model_path(&path)
} }
pub fn from_model_path(path: &Path) -> Result<Self, Box<dyn Error>> { pub fn from_model_path(path: &Path) -> AppResult<Self> {
let inner_session = create_session(path)?; let inner_session = create_session(path)?;
let input_name = inner_session.inputs()[0].name().to_string(); let input_name = inner_session.inputs()[0].name().to_string();
Ok(Self { Ok(Self {

View file

@ -1,8 +1,9 @@
use ort::session::Session as OrtSession; use {ort::session::Session as OrtSession, std::path::Path};
use std::error::Error;
use std::path::Path;
use crate::model::{create_session, get_model_path}; use crate::{
error::AppResult,
model::{create_session, get_model_path},
};
use super::Session; use super::Session;
@ -12,12 +13,12 @@ pub struct BriaSession {
} }
impl BriaSession { impl BriaSession {
pub fn new(offline: bool) -> Result<Self, Box<dyn Error>> { pub fn new(offline: bool) -> AppResult<Self> {
let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?; let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?;
Self::from_model_path(&path) Self::from_model_path(&path)
} }
pub fn from_model_path(path: &Path) -> Result<Self, Box<dyn Error>> { pub fn from_model_path(path: &Path) -> AppResult<Self> {
let inner_session = create_session(path)?; let inner_session = create_session(path)?;
let input_name = inner_session.inputs()[0].name().to_string(); let input_name = inner_session.inputs()[0].name().to_string();
Ok(Self { Ok(Self {

View file

@ -2,11 +2,14 @@ use {
image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType}, image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType},
ndarray::{Array4, IntoDimension}, ndarray::{Array4, IntoDimension},
ort::{session::Session as OrtSession, value::Tensor}, ort::{session::Session as OrtSession, value::Tensor},
std::{error::Error, time::Instant}, std::time::Instant,
tracing::{debug, info}, tracing::{debug, info},
}; };
use crate::model::Model; use crate::{
error::{AppError, AppResult},
model::Model,
};
mod birefnet_lite; mod birefnet_lite;
mod bria; mod bria;
@ -62,7 +65,7 @@ pub trait Session {
/// Port of rembg's `BaseSession.normalize`: resize with LANCZOS, /// Port of rembg's `BaseSession.normalize`: resize with LANCZOS,
/// scale into `[0, 1]` by dividing by the max pixel value, then apply /// scale into `[0, 1]` by dividing by the max pixel value, then apply
/// channel-wise mean/std. /// channel-wise mean/std.
fn normalize(&self, img: &DynamicImage) -> Result<Array4<f32>, Box<dyn Error>> { fn normalize(&self, img: &DynamicImage) -> AppResult<Array4<f32>> {
let (w, h) = self.input_size(); let (w, h) = self.input_size();
let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8(); let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8();
let (width, height) = resized.dimensions(); let (width, height) = resized.dimensions();
@ -104,12 +107,15 @@ pub trait Session {
/// 3. optional sigmoid (birefnet) /// 3. optional sigmoid (birefnet)
/// 4. min/max normalize into `[0, 1]` /// 4. min/max normalize into `[0, 1]`
/// 5. scale to `u8`, resize to original image dimensions /// 5. scale to `u8`, resize to original image dimensions
fn predict(&mut self, img: &DynamicImage) -> Result<GrayImage, Box<dyn Error>> { fn predict(&mut self, img: &DynamicImage) -> AppResult<GrayImage> {
let (orig_w, orig_h) = img.dimensions(); let (orig_w, orig_h) = img.dimensions();
let preprocess_start = Instant::now(); let preprocess_start = Instant::now();
let input = self.normalize(img)?; let input = self.normalize(img)?;
debug!(elapsed_ms = preprocess_start.elapsed().as_millis() as u64, "Preprocessing complete"); debug!(
elapsed_ms = preprocess_start.elapsed().as_millis() as u64,
"Preprocessing complete"
);
let apply_sigmoid = self.apply_sigmoid(); let apply_sigmoid = self.apply_sigmoid();
let input_name = self.input_name().to_string(); let input_name = self.input_name().to_string();
@ -127,11 +133,9 @@ pub trait Session {
let (shape, data) = output.try_extract_tensor::<f32>()?; let (shape, data) = output.try_extract_tensor::<f32>()?;
if shape.len() != 4 { if shape.len() != 4 {
return Err(format!( return Err(AppError::UnexpectedTensorShape {
"Expected 4D output tensor [N, C, H, W], got shape {:?}", shape: shape.iter().copied().collect(),
shape });
)
.into());
} }
let (n, _c, h, w) = ( let (n, _c, h, w) = (
shape[0] as usize, shape[0] as usize,
@ -140,7 +144,7 @@ pub trait Session {
shape[3] as usize, shape[3] as usize,
); );
if n != 1 { if n != 1 {
return Err(format!("Expected batch size 1, got {}", n).into()); return Err(AppError::UnexpectedBatchSize { batch_size: n });
} }
let view = ndarray::ArrayView::from_shape( let view = ndarray::ArrayView::from_shape(
@ -200,11 +204,11 @@ pub fn init_session(
custom_model_path: Option<&str>, custom_model_path: Option<&str>,
model: Model, model: Model,
offline: bool, offline: bool,
) -> Result<Box<dyn Session>, Box<dyn Error>> { ) -> AppResult<Box<dyn Session>> {
Ok(if let Some(custom_path) = custom_model_path { Ok(if let Some(custom_path) = custom_model_path {
let path = std::path::PathBuf::from(custom_path); let path = std::path::PathBuf::from(custom_path);
if !path.exists() { if !path.exists() {
return Err(format!("Custom model path does not exist: {}", custom_path).into()); return Err(AppError::CustomModelPathMissing { path });
} }
info!(path = %custom_path, "Using custom model"); info!(path = %custom_path, "Using custom model");
match model { match model {