Compare commits
5 commits
31d84c99d7
...
25a4a204b9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25a4a204b9 | ||
|
|
5660af24f7 | ||
|
|
c217578536 | ||
|
|
ebdbe13f26 | ||
|
|
3be68fcccf |
12 changed files with 288 additions and 106 deletions
6
Cargo.lock
generated
6
Cargo.lock
generated
|
|
@ -2750,6 +2750,7 @@ checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
|
||||||
name = "remove_background"
|
name = "remove_background"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"clap",
|
"clap",
|
||||||
"hex",
|
"hex",
|
||||||
"image",
|
"image",
|
||||||
|
|
@ -2758,6 +2759,7 @@ dependencies = [
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"sha2",
|
"sha2",
|
||||||
"show-image",
|
"show-image",
|
||||||
|
"thiserror 2.0.18",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
|
|
@ -2870,9 +2872,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls"
|
name = "rustls"
|
||||||
version = "0.23.38"
|
version = "0.23.39"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "69f9466fb2c14ea04357e91413efb882e2a6d4a406e625449bc0a5d360d53a21"
|
checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aws-lc-rs",
|
"aws-lc-rs",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,20 @@
|
||||||
name = "remove_background"
|
name = "remove_background"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
cuda = ["ort/cuda"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anyhow = "1"
|
||||||
clap = { version = "4", features = ["derive"] }
|
clap = { version = "4", features = ["derive"] }
|
||||||
hex = "0.4"
|
hex = "0.4"
|
||||||
image = "0.25"
|
image = "0.25"
|
||||||
ndarray = "0.17"
|
ndarray = "0.17"
|
||||||
ort = "=2.0.0-rc.12"
|
ort = { version = "=2.0.0-rc.12" }
|
||||||
reqwest = { version = "0.13", features = ["blocking"] }
|
reqwest = { version = "0.13", features = ["blocking"] }
|
||||||
sha2 = "0.11"
|
sha2 = "0.11"
|
||||||
show-image = { version = "0.14", features = ["image"] }
|
show-image = { version = "0.14", features = ["image"] }
|
||||||
|
thiserror = "2"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
|
||||||
|
|
|
||||||
|
|
@ -62,11 +62,11 @@
|
||||||
"nixpkgs": "nixpkgs_2"
|
"nixpkgs": "nixpkgs_2"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1776827647,
|
"lastModified": 1776914043,
|
||||||
"narHash": "sha256-sYixYhp5V8jCajO8TRorE4fzs7IkL4MZdfLTKgkPQBk=",
|
"narHash": "sha256-qug5r56yW1qOsjSI99l3Jm15JNT9CvS2otkXNRNtrPI=",
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "40e6ccc06e1245a4837cbbd6bdda64e21cc67379",
|
"rev": "2d35c4358d7de3a0e606a6e8b27925d981c01cc3",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
|
||||||
28
flake.nix
28
flake.nix
|
|
@ -18,7 +18,10 @@
|
||||||
system:
|
system:
|
||||||
let
|
let
|
||||||
overlays = [ (import rust-overlay) ];
|
overlays = [ (import rust-overlay) ];
|
||||||
pkgs = import nixpkgs { inherit system overlays; };
|
pkgs = import nixpkgs {
|
||||||
|
inherit system overlays;
|
||||||
|
config.allowUnfree = true;
|
||||||
|
};
|
||||||
stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv;
|
stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv;
|
||||||
rustToolchain = pkgs.rust-bin.stable.latest.default.override {
|
rustToolchain = pkgs.rust-bin.stable.latest.default.override {
|
||||||
extensions = [ "rust-src" ];
|
extensions = [ "rust-src" ];
|
||||||
|
|
@ -32,10 +35,10 @@
|
||||||
];
|
];
|
||||||
|
|
||||||
xorgBuildInputs = with pkgs; [
|
xorgBuildInputs = with pkgs; [
|
||||||
xorg.libx11
|
libx11
|
||||||
xorg.libxcursor
|
libxcursor
|
||||||
xorg.libxi
|
libxi
|
||||||
xorg.libxrandr
|
libxrandr
|
||||||
];
|
];
|
||||||
|
|
||||||
waylandBuildInputs = with pkgs; [
|
waylandBuildInputs = with pkgs; [
|
||||||
|
|
@ -49,7 +52,20 @@
|
||||||
openssl
|
openssl
|
||||||
];
|
];
|
||||||
|
|
||||||
buildInputs = xorgBuildInputs ++ waylandBuildInputs ++ graphicsBuildInputs;
|
# CUDA runtime libraries required by ONNX Runtime CUDA EP.
|
||||||
|
cudaBuildInputs = with pkgs.cudaPackages; [
|
||||||
|
cuda_cudart
|
||||||
|
libcublas
|
||||||
|
libcurand
|
||||||
|
libcufft
|
||||||
|
cudnn
|
||||||
|
];
|
||||||
|
|
||||||
|
buildInputs =
|
||||||
|
xorgBuildInputs
|
||||||
|
++ waylandBuildInputs
|
||||||
|
++ graphicsBuildInputs
|
||||||
|
++ cudaBuildInputs;
|
||||||
|
|
||||||
mkShell = pkgs.mkShell.override {
|
mkShell = pkgs.mkShell.override {
|
||||||
stdenv = stdenv;
|
stdenv = stdenv;
|
||||||
|
|
|
||||||
48
src/error.rs
Normal file
48
src/error.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
use {std::path::PathBuf, thiserror::Error};
|
||||||
|
|
||||||
|
pub type AppResult<T> = Result<T, AppError>;
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum AppError {
|
||||||
|
#[error(transparent)]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
EnvVar(#[from] std::env::VarError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Image(#[from] image::ImageError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Reqwest(#[from] reqwest::Error),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Ort(#[from] ort::Error),
|
||||||
|
|
||||||
|
#[error("failed to configure ONNX session builder: {message}")]
|
||||||
|
OrtSessionBuilder { message: String },
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
NdarrayShape(#[from] ndarray::ShapeError),
|
||||||
|
|
||||||
|
#[error("no input data provided via stdin")]
|
||||||
|
NoStdinInput,
|
||||||
|
|
||||||
|
#[error("download failed: HTTP {status}")]
|
||||||
|
DownloadHttpStatus { status: reqwest::StatusCode },
|
||||||
|
|
||||||
|
#[error("custom model path does not exist: {path}")]
|
||||||
|
CustomModelPathMissing { path: PathBuf },
|
||||||
|
|
||||||
|
#[error("model hash verification failed: {path}")]
|
||||||
|
ModelHashVerificationFailed { path: PathBuf },
|
||||||
|
|
||||||
|
#[error("model not found in cache and offline mode is enabled: {cache_path}")]
|
||||||
|
OfflineModelMissing { cache_path: PathBuf },
|
||||||
|
|
||||||
|
#[error("expected 4D output tensor [N, C, H, W], got shape {shape:?}")]
|
||||||
|
UnexpectedTensorShape { shape: Vec<i64> },
|
||||||
|
|
||||||
|
#[error("expected batch size 1, got {batch_size}")]
|
||||||
|
UnexpectedBatchSize { batch_size: usize },
|
||||||
|
}
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
pub mod error;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod postprocessing;
|
pub mod postprocessing;
|
||||||
pub mod sessions;
|
pub mod sessions;
|
||||||
|
|
|
||||||
86
src/main.rs
86
src/main.rs
|
|
@ -1,14 +1,15 @@
|
||||||
use {
|
use {
|
||||||
|
anyhow::{Context, Result},
|
||||||
clap::{Parser, Subcommand},
|
clap::{Parser, Subcommand},
|
||||||
image::{GenericImageView, ImageReader},
|
image::{GenericImageView, ImageReader},
|
||||||
remove_background::{
|
remove_background::{
|
||||||
|
error::AppError,
|
||||||
model::Model,
|
model::Model,
|
||||||
postprocessing::{apply_mask, create_side_by_side},
|
postprocessing::{apply_mask, create_side_by_side},
|
||||||
sessions::init_session,
|
sessions::init_session,
|
||||||
},
|
},
|
||||||
show_image::{AsImageView, create_window, event},
|
show_image::{AsImageView, create_window, event},
|
||||||
std::{
|
std::{
|
||||||
error::Error,
|
|
||||||
fs,
|
fs,
|
||||||
io::{Cursor, Read, Write, stdin, stdout},
|
io::{Cursor, Read, Write, stdin, stdout},
|
||||||
path::PathBuf,
|
path::PathBuf,
|
||||||
|
|
@ -68,9 +69,11 @@ struct Args {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[show_image::main]
|
#[show_image::main]
|
||||||
fn main() -> Result<(), Box<dyn Error>> {
|
fn main() -> Result<()> {
|
||||||
fmt()
|
fmt()
|
||||||
.with_env_filter(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
|
.with_env_filter(
|
||||||
|
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
|
||||||
|
)
|
||||||
.with_writer(std::io::stderr)
|
.with_writer(std::io::stderr)
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
|
|
@ -83,7 +86,8 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
"Starting remove_background"
|
"Starting remove_background"
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?;
|
let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)
|
||||||
|
.context("failed to initialize inference session")?;
|
||||||
|
|
||||||
match args.command {
|
match args.command {
|
||||||
Command::Single {
|
Command::Single {
|
||||||
|
|
@ -93,44 +97,65 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
} => {
|
} => {
|
||||||
let img = if let Some(input_file) = input_file {
|
let img = if let Some(input_file) = input_file {
|
||||||
debug!(path = %input_file.display(), "Reading input image from file");
|
debug!(path = %input_file.display(), "Reading input image from file");
|
||||||
ImageReader::open(input_file)?.decode()?
|
ImageReader::open(&input_file)
|
||||||
|
.with_context(|| {
|
||||||
|
format!("failed to open input image {}", input_file.display())
|
||||||
|
})?
|
||||||
|
.decode()
|
||||||
|
.with_context(|| {
|
||||||
|
format!("failed to decode input image {}", input_file.display())
|
||||||
|
})?
|
||||||
} else {
|
} else {
|
||||||
debug!("Reading input image from stdin");
|
debug!("Reading input image from stdin");
|
||||||
let mut bytes = Vec::new();
|
let mut bytes = Vec::new();
|
||||||
if stdin().lock().read_to_end(&mut bytes)? == 0 {
|
if stdin().lock().read_to_end(&mut bytes)? == 0 {
|
||||||
return Err("No input data provided via stdin".into());
|
return Err(AppError::NoStdinInput.into());
|
||||||
}
|
}
|
||||||
ImageReader::new(Cursor::new(bytes))
|
ImageReader::new(Cursor::new(bytes))
|
||||||
.with_guessed_format()?
|
.with_guessed_format()?
|
||||||
.decode()?
|
.decode()
|
||||||
|
.context("failed to decode input image from stdin")?
|
||||||
};
|
};
|
||||||
|
|
||||||
let (img_width, img_height) = img.dimensions();
|
let (img_width, img_height) = img.dimensions();
|
||||||
info!(width = img_width, height = img_height, "Loaded image");
|
info!(width = img_width, height = img_height, "Loaded image");
|
||||||
|
|
||||||
let mask = session.predict(&img)?;
|
let mask = session
|
||||||
let result_rgba = apply_mask(&img, &mask)?;
|
.predict(&img)
|
||||||
|
.context("failed to predict segmentation mask")?;
|
||||||
|
let result_rgba = apply_mask(&img, &mask).context("failed to apply mask to image")?;
|
||||||
|
|
||||||
if debug {
|
if debug {
|
||||||
debug_mode(&img, &result_rgba)?;
|
debug_mode(&img, &result_rgba)?;
|
||||||
}
|
} else {
|
||||||
|
|
||||||
if let Some(output_file) = output_file {
|
if let Some(output_file) = output_file {
|
||||||
result_rgba.save(&output_file)?;
|
result_rgba.save(&output_file).with_context(|| {
|
||||||
|
format!("failed to save output image {}", output_file.display())
|
||||||
|
})?;
|
||||||
info!(path = %output_file.display(), "Wrote output image to file");
|
info!(path = %output_file.display(), "Wrote output image to file");
|
||||||
} else {
|
} else {
|
||||||
debug!("Writing output image to stdout (PNG)");
|
debug!("Writing output image to stdout (PNG)");
|
||||||
let mut buffer = Cursor::new(Vec::new());
|
let mut buffer = Cursor::new(Vec::new());
|
||||||
result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?;
|
result_rgba
|
||||||
|
.write_to(&mut buffer, image::ImageFormat::Png)
|
||||||
|
.context("failed to encode PNG for stdout output")?;
|
||||||
let mut stdout = stdout().lock();
|
let mut stdout = stdout().lock();
|
||||||
stdout.write_all(&buffer.into_inner())?;
|
stdout
|
||||||
|
.write_all(&buffer.into_inner())
|
||||||
|
.context("failed to write PNG bytes to stdout")?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Command::Batch {
|
Command::Batch {
|
||||||
input_directory,
|
input_directory,
|
||||||
output_directory,
|
output_directory,
|
||||||
} => {
|
} => {
|
||||||
fs::create_dir_all(&output_directory)?;
|
fs::create_dir_all(&output_directory).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"failed to create output directory {}",
|
||||||
|
output_directory.display()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
info!(
|
info!(
|
||||||
input = %input_directory.display(),
|
input = %input_directory.display(),
|
||||||
output = %output_directory.display(),
|
output = %output_directory.display(),
|
||||||
|
|
@ -139,7 +164,12 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
|
||||||
let mut processed: usize = 0;
|
let mut processed: usize = 0;
|
||||||
let mut failed: usize = 0;
|
let mut failed: usize = 0;
|
||||||
for entry in fs::read_dir(&input_directory)? {
|
for entry in fs::read_dir(&input_directory).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"failed to read input directory {}",
|
||||||
|
input_directory.display()
|
||||||
|
)
|
||||||
|
})? {
|
||||||
let entry = entry?;
|
let entry = entry?;
|
||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
let span = tracing::info_span!("batch_item", path = %path.display());
|
let span = tracing::info_span!("batch_item", path = %path.display());
|
||||||
|
|
@ -161,10 +191,16 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mask = session.predict(&img)?;
|
let mask = session
|
||||||
let result_rgba = apply_mask(&img, &mask)?;
|
.predict(&img)
|
||||||
let output_path = output_directory.join(path.file_name().unwrap());
|
.with_context(|| format!("failed to predict mask for {}", path.display()))?;
|
||||||
result_rgba.save(&output_path)?;
|
let result_rgba = apply_mask(&img, &mask)
|
||||||
|
.with_context(|| format!("failed to apply mask for {}", path.display()))?;
|
||||||
|
let mut output_path = output_directory.join(path.file_name().unwrap());
|
||||||
|
output_path.set_extension("png");
|
||||||
|
result_rgba.save(&output_path).with_context(|| {
|
||||||
|
format!("failed to save processed image {}", output_path.display())
|
||||||
|
})?;
|
||||||
processed += 1;
|
processed += 1;
|
||||||
info!(output = %output_path.display(), "Processed image saved");
|
info!(output = %output_path.display(), "Processed image saved");
|
||||||
}
|
}
|
||||||
|
|
@ -177,12 +213,10 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn debug_mode(
|
fn debug_mode(img: &image::DynamicImage, result_rgba: &image::RgbaImage) -> Result<()> {
|
||||||
img: &image::DynamicImage,
|
|
||||||
result_rgba: &image::RgbaImage,
|
|
||||||
) -> Result<(), Box<dyn Error>> {
|
|
||||||
info!("Creating side-by-side comparison");
|
info!("Creating side-by-side comparison");
|
||||||
let composite = create_side_by_side(img, result_rgba)?;
|
let composite = create_side_by_side(img, result_rgba)
|
||||||
|
.context("failed to create side-by-side debug image")?;
|
||||||
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
|
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
|
||||||
|
|
||||||
let (comp_width, comp_height) = composite_dynamic.dimensions();
|
let (comp_width, comp_height) = composite_dynamic.dimensions();
|
||||||
|
|
@ -194,7 +228,7 @@ fn debug_mode(
|
||||||
"comparison",
|
"comparison",
|
||||||
composite_dynamic
|
composite_dynamic
|
||||||
.as_image_view()
|
.as_image_view()
|
||||||
.map_err(|e| e.to_string())?,
|
.map_err(|e| anyhow::anyhow!(e.to_string()))?,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
|
|
|
||||||
130
src/model.rs
130
src/model.rs
|
|
@ -1,12 +1,8 @@
|
||||||
use {
|
use {
|
||||||
clap::ValueEnum,
|
clap::ValueEnum,
|
||||||
ort::{
|
ort::session::Session,
|
||||||
execution_providers::CUDAExecutionProvider,
|
|
||||||
session::{Session, builder::GraphOptimizationLevel},
|
|
||||||
},
|
|
||||||
sha2::{Digest, Sha256},
|
sha2::{Digest, Sha256},
|
||||||
std::{
|
std::{
|
||||||
error::Error,
|
|
||||||
fs,
|
fs,
|
||||||
io::Read,
|
io::Read,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
|
|
@ -15,6 +11,11 @@ use {
|
||||||
tracing::{debug, info, warn},
|
tracing::{debug, info, warn},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::error::{AppError, AppResult};
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use ort::{builder::GraphOptimizationLevel, ep::CUDAExecutionProvider};
|
||||||
|
|
||||||
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
|
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
|
||||||
/// preprocessing params) lives on the `Session` trait impls in `sessions/`.
|
/// preprocessing params) lives on the `Session` trait impls in `sessions/`.
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
|
@ -25,7 +26,7 @@ pub enum Model {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the cache directory for models
|
/// Get the cache directory for models
|
||||||
fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
|
fn get_cache_dir() -> AppResult<PathBuf> {
|
||||||
let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?;
|
let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?;
|
||||||
let cache_dir = Path::new(&home)
|
let cache_dir = Path::new(&home)
|
||||||
.join(".cache")
|
.join(".cache")
|
||||||
|
|
@ -35,7 +36,7 @@ fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
|
||||||
Ok(cache_dir)
|
Ok(cache_dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
fn download_file(url: &str, dest: &Path) -> AppResult<()> {
|
||||||
info!(%url, dest = %dest.display(), "Downloading model");
|
info!(%url, dest = %dest.display(), "Downloading model");
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
|
|
@ -46,7 +47,9 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
||||||
let mut response = client.get(url).send()?;
|
let mut response = client.get(url).send()?;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
return Err(format!("Failed to download model: HTTP {}", response.status()).into());
|
return Err(AppError::DownloadHttpStatus {
|
||||||
|
status: response.status(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut file = std::io::BufWriter::new(fs::File::create(dest)?);
|
let mut file = std::io::BufWriter::new(fs::File::create(dest)?);
|
||||||
|
|
@ -68,7 +71,12 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
||||||
// Only emit every 10% to avoid log spam while still giving feedback.
|
// Only emit every 10% to avoid log spam while still giving feedback.
|
||||||
if progress >= last_reported_pct + 10 {
|
if progress >= last_reported_pct + 10 {
|
||||||
last_reported_pct = progress - (progress % 10);
|
last_reported_pct = progress - (progress % 10);
|
||||||
debug!(percent = last_reported_pct, downloaded, total = total_size, "Download progress");
|
debug!(
|
||||||
|
percent = last_reported_pct,
|
||||||
|
downloaded,
|
||||||
|
total = total_size,
|
||||||
|
"Download progress"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -82,7 +90,7 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Error>> {
|
fn verify_hash(file_path: &Path, expected_hash: &str) -> AppResult<bool> {
|
||||||
const BUF_SIZE: usize = 8192;
|
const BUF_SIZE: usize = 8192;
|
||||||
|
|
||||||
let mut file = fs::File::open(file_path)?;
|
let mut file = fs::File::open(file_path)?;
|
||||||
|
|
@ -111,11 +119,11 @@ pub fn get_model_path(
|
||||||
sha256: Option<&str>,
|
sha256: Option<&str>,
|
||||||
custom_path: Option<&str>,
|
custom_path: Option<&str>,
|
||||||
offline: bool,
|
offline: bool,
|
||||||
) -> Result<PathBuf, Box<dyn Error>> {
|
) -> AppResult<PathBuf> {
|
||||||
if let Some(path) = custom_path {
|
if let Some(path) = custom_path {
|
||||||
let model_path = PathBuf::from(path);
|
let model_path = PathBuf::from(path);
|
||||||
if !model_path.exists() {
|
if !model_path.exists() {
|
||||||
return Err(format!("Custom model path does not exist: {}", path).into());
|
return Err(AppError::CustomModelPathMissing { path: model_path });
|
||||||
}
|
}
|
||||||
info!(%path, "Using custom model");
|
info!(%path, "Using custom model");
|
||||||
return Ok(model_path);
|
return Ok(model_path);
|
||||||
|
|
@ -134,9 +142,9 @@ pub fn get_model_path(
|
||||||
debug!("Cached model hash OK");
|
debug!("Cached model hash OK");
|
||||||
} else {
|
} else {
|
||||||
warn!(path = %model_path.display(), "Cached model hash verification failed");
|
warn!(path = %model_path.display(), "Cached model hash verification failed");
|
||||||
return Err(
|
return Err(AppError::ModelHashVerificationFailed {
|
||||||
"Model hash verification failed. Try deleting the cached model.".into(),
|
path: model_path.clone(),
|
||||||
);
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -144,11 +152,9 @@ pub fn get_model_path(
|
||||||
}
|
}
|
||||||
|
|
||||||
if offline {
|
if offline {
|
||||||
return Err(format!(
|
return Err(AppError::OfflineModelMissing {
|
||||||
"Model not found in cache and offline mode is enabled. Cache path: {}",
|
cache_path: model_path,
|
||||||
model_path.display()
|
});
|
||||||
)
|
|
||||||
.into());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Model not found in cache, downloading");
|
info!("Model not found in cache, downloading");
|
||||||
|
|
@ -160,23 +166,40 @@ pub fn get_model_path(
|
||||||
debug!("Downloaded model hash OK");
|
debug!("Downloaded model hash OK");
|
||||||
} else {
|
} else {
|
||||||
fs::remove_file(&model_path)?;
|
fs::remove_file(&model_path)?;
|
||||||
return Err("Downloaded model hash verification failed".into());
|
return Err(AppError::ModelHashVerificationFailed { path: model_path });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(model_path)
|
Ok(model_path)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU).
|
/// Create an ONNX Runtime session from a model path.
|
||||||
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
///
|
||||||
|
/// Uses the CUDA execution provider when built with `--features cuda`; otherwise runs on CPU.
|
||||||
|
pub fn create_session(model_path: &Path) -> AppResult<Session> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend");
|
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend");
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CPU backend");
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let session = Session::builder()?
|
let mut builder = Session::builder()?;
|
||||||
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
#[cfg(feature = "cuda")]
|
||||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
let builder = builder
|
||||||
.with_intra_threads(4)?
|
.with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||||
.commit_from_file(model_path)?;
|
.map_err(|err| AppError::OrtSessionBuilder {
|
||||||
|
message: err.to_string(),
|
||||||
|
})?
|
||||||
|
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||||
|
.map_err(|err| AppError::OrtSessionBuilder {
|
||||||
|
message: err.to_string(),
|
||||||
|
})?
|
||||||
|
.with_intra_threads(4)
|
||||||
|
.map_err(|err| AppError::OrtSessionBuilder {
|
||||||
|
message: err.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let session = builder.commit_from_file(model_path)?;
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
elapsed_ms = start.elapsed().as_millis() as u64,
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
|
@ -185,3 +208,54 @@ pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
||||||
|
|
||||||
Ok(session)
|
Ok(session)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn custom_model_path_missing_returns_typed_error() {
|
||||||
|
let missing_path = PathBuf::from("/definitely/not/a/real/model.onnx");
|
||||||
|
let result = get_model_path(
|
||||||
|
"unused",
|
||||||
|
"https://example.invalid/model.onnx",
|
||||||
|
None,
|
||||||
|
Some(missing_path.to_str().expect("valid utf-8 path")),
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Err(AppError::CustomModelPathMissing { path }) => assert_eq!(path, missing_path),
|
||||||
|
other => panic!("unexpected result: {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn offline_missing_model_returns_typed_error() {
|
||||||
|
let unique_name = format!(
|
||||||
|
"test-model-{}",
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.expect("system time before epoch")
|
||||||
|
.as_nanos()
|
||||||
|
);
|
||||||
|
let result = get_model_path(
|
||||||
|
&unique_name,
|
||||||
|
"https://example.invalid/model.onnx",
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
true,
|
||||||
|
);
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Err(AppError::OfflineModelMissing { cache_path }) => {
|
||||||
|
let expected = format!("{unique_name}.onnx");
|
||||||
|
assert_eq!(
|
||||||
|
cache_path.file_name().and_then(|f| f.to_str()),
|
||||||
|
Some(expected.as_str())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
other => panic!("unexpected result: {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,15 @@
|
||||||
use {
|
use {
|
||||||
image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType},
|
image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType},
|
||||||
std::error::Error,
|
|
||||||
tracing::debug,
|
tracing::debug,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::error::AppResult;
|
||||||
|
|
||||||
/// Compose `original` with `mask` as the alpha channel and return an RGBA image.
|
/// Compose `original` with `mask` as the alpha channel and return an RGBA image.
|
||||||
///
|
///
|
||||||
/// The mask is expected to already be grayscale. If its dimensions differ from
|
/// The mask is expected to already be grayscale. If its dimensions differ from
|
||||||
/// the original, it is resized with LANCZOS3.
|
/// the original, it is resized with LANCZOS3.
|
||||||
pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage, Box<dyn Error>> {
|
pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> AppResult<RgbaImage> {
|
||||||
let (orig_width, orig_height) = original.dimensions();
|
let (orig_width, orig_height) = original.dimensions();
|
||||||
let (mask_width, mask_height) = mask.dimensions();
|
let (mask_width, mask_height) = mask.dimensions();
|
||||||
|
|
||||||
|
|
@ -16,10 +17,7 @@ pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage
|
||||||
if mask_width != orig_width || mask_height != orig_height {
|
if mask_width != orig_width || mask_height != orig_height {
|
||||||
debug!(
|
debug!(
|
||||||
mask_width,
|
mask_width,
|
||||||
mask_height,
|
mask_height, orig_width, orig_height, "Resizing mask to match original image"
|
||||||
orig_width,
|
|
||||||
orig_height,
|
|
||||||
"Resizing mask to match original image"
|
|
||||||
);
|
);
|
||||||
std::borrow::Cow::Owned(image::imageops::resize(
|
std::borrow::Cow::Owned(image::imageops::resize(
|
||||||
mask,
|
mask,
|
||||||
|
|
@ -53,10 +51,7 @@ pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a side-by-side comparison image
|
/// Create a side-by-side comparison image
|
||||||
pub fn create_side_by_side(
|
pub fn create_side_by_side(original: &DynamicImage, result: &RgbaImage) -> AppResult<RgbaImage> {
|
||||||
original: &DynamicImage,
|
|
||||||
result: &RgbaImage,
|
|
||||||
) -> Result<RgbaImage, Box<dyn Error>> {
|
|
||||||
let (width, height) = original.dimensions();
|
let (width, height) = original.dimensions();
|
||||||
let mut composite = RgbaImage::new(width * 2, height);
|
let mut composite = RgbaImage::new(width * 2, height);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
use ort::session::Session as OrtSession;
|
use {ort::session::Session as OrtSession, std::path::Path};
|
||||||
use std::error::Error;
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
use crate::model::{create_session, get_model_path};
|
use crate::{
|
||||||
|
error::AppResult,
|
||||||
|
model::{create_session, get_model_path},
|
||||||
|
};
|
||||||
|
|
||||||
use super::Session;
|
use super::Session;
|
||||||
|
|
||||||
|
|
@ -12,12 +13,12 @@ pub struct BiRefNetLiteSession {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BiRefNetLiteSession {
|
impl BiRefNetLiteSession {
|
||||||
pub fn new(offline: bool) -> Result<Self, Box<dyn Error>> {
|
pub fn new(offline: bool) -> AppResult<Self> {
|
||||||
let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?;
|
let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?;
|
||||||
Self::from_model_path(&path)
|
Self::from_model_path(&path)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_model_path(path: &Path) -> Result<Self, Box<dyn Error>> {
|
pub fn from_model_path(path: &Path) -> AppResult<Self> {
|
||||||
let inner_session = create_session(path)?;
|
let inner_session = create_session(path)?;
|
||||||
let input_name = inner_session.inputs()[0].name().to_string();
|
let input_name = inner_session.inputs()[0].name().to_string();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
use ort::session::Session as OrtSession;
|
use {ort::session::Session as OrtSession, std::path::Path};
|
||||||
use std::error::Error;
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
use crate::model::{create_session, get_model_path};
|
use crate::{
|
||||||
|
error::AppResult,
|
||||||
|
model::{create_session, get_model_path},
|
||||||
|
};
|
||||||
|
|
||||||
use super::Session;
|
use super::Session;
|
||||||
|
|
||||||
|
|
@ -12,12 +13,12 @@ pub struct BriaSession {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BriaSession {
|
impl BriaSession {
|
||||||
pub fn new(offline: bool) -> Result<Self, Box<dyn Error>> {
|
pub fn new(offline: bool) -> AppResult<Self> {
|
||||||
let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?;
|
let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?;
|
||||||
Self::from_model_path(&path)
|
Self::from_model_path(&path)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_model_path(path: &Path) -> Result<Self, Box<dyn Error>> {
|
pub fn from_model_path(path: &Path) -> AppResult<Self> {
|
||||||
let inner_session = create_session(path)?;
|
let inner_session = create_session(path)?;
|
||||||
let input_name = inner_session.inputs()[0].name().to_string();
|
let input_name = inner_session.inputs()[0].name().to_string();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,14 @@ use {
|
||||||
image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType},
|
image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType},
|
||||||
ndarray::{Array4, IntoDimension},
|
ndarray::{Array4, IntoDimension},
|
||||||
ort::{session::Session as OrtSession, value::Tensor},
|
ort::{session::Session as OrtSession, value::Tensor},
|
||||||
std::{error::Error, time::Instant},
|
std::time::Instant,
|
||||||
tracing::{debug, info},
|
tracing::{debug, info},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::model::Model;
|
use crate::{
|
||||||
|
error::{AppError, AppResult},
|
||||||
|
model::Model,
|
||||||
|
};
|
||||||
|
|
||||||
mod birefnet_lite;
|
mod birefnet_lite;
|
||||||
mod bria;
|
mod bria;
|
||||||
|
|
@ -62,7 +65,7 @@ pub trait Session {
|
||||||
/// Port of rembg's `BaseSession.normalize`: resize with LANCZOS,
|
/// Port of rembg's `BaseSession.normalize`: resize with LANCZOS,
|
||||||
/// scale into `[0, 1]` by dividing by the max pixel value, then apply
|
/// scale into `[0, 1]` by dividing by the max pixel value, then apply
|
||||||
/// channel-wise mean/std.
|
/// channel-wise mean/std.
|
||||||
fn normalize(&self, img: &DynamicImage) -> Result<Array4<f32>, Box<dyn Error>> {
|
fn normalize(&self, img: &DynamicImage) -> AppResult<Array4<f32>> {
|
||||||
let (w, h) = self.input_size();
|
let (w, h) = self.input_size();
|
||||||
let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8();
|
let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8();
|
||||||
let (width, height) = resized.dimensions();
|
let (width, height) = resized.dimensions();
|
||||||
|
|
@ -104,12 +107,15 @@ pub trait Session {
|
||||||
/// 3. optional sigmoid (birefnet)
|
/// 3. optional sigmoid (birefnet)
|
||||||
/// 4. min/max normalize into `[0, 1]`
|
/// 4. min/max normalize into `[0, 1]`
|
||||||
/// 5. scale to `u8`, resize to original image dimensions
|
/// 5. scale to `u8`, resize to original image dimensions
|
||||||
fn predict(&mut self, img: &DynamicImage) -> Result<GrayImage, Box<dyn Error>> {
|
fn predict(&mut self, img: &DynamicImage) -> AppResult<GrayImage> {
|
||||||
let (orig_w, orig_h) = img.dimensions();
|
let (orig_w, orig_h) = img.dimensions();
|
||||||
|
|
||||||
let preprocess_start = Instant::now();
|
let preprocess_start = Instant::now();
|
||||||
let input = self.normalize(img)?;
|
let input = self.normalize(img)?;
|
||||||
debug!(elapsed_ms = preprocess_start.elapsed().as_millis() as u64, "Preprocessing complete");
|
debug!(
|
||||||
|
elapsed_ms = preprocess_start.elapsed().as_millis() as u64,
|
||||||
|
"Preprocessing complete"
|
||||||
|
);
|
||||||
let apply_sigmoid = self.apply_sigmoid();
|
let apply_sigmoid = self.apply_sigmoid();
|
||||||
|
|
||||||
let input_name = self.input_name().to_string();
|
let input_name = self.input_name().to_string();
|
||||||
|
|
@ -127,11 +133,9 @@ pub trait Session {
|
||||||
let (shape, data) = output.try_extract_tensor::<f32>()?;
|
let (shape, data) = output.try_extract_tensor::<f32>()?;
|
||||||
|
|
||||||
if shape.len() != 4 {
|
if shape.len() != 4 {
|
||||||
return Err(format!(
|
return Err(AppError::UnexpectedTensorShape {
|
||||||
"Expected 4D output tensor [N, C, H, W], got shape {:?}",
|
shape: shape.iter().copied().collect(),
|
||||||
shape
|
});
|
||||||
)
|
|
||||||
.into());
|
|
||||||
}
|
}
|
||||||
let (n, _c, h, w) = (
|
let (n, _c, h, w) = (
|
||||||
shape[0] as usize,
|
shape[0] as usize,
|
||||||
|
|
@ -140,7 +144,7 @@ pub trait Session {
|
||||||
shape[3] as usize,
|
shape[3] as usize,
|
||||||
);
|
);
|
||||||
if n != 1 {
|
if n != 1 {
|
||||||
return Err(format!("Expected batch size 1, got {}", n).into());
|
return Err(AppError::UnexpectedBatchSize { batch_size: n });
|
||||||
}
|
}
|
||||||
|
|
||||||
let view = ndarray::ArrayView::from_shape(
|
let view = ndarray::ArrayView::from_shape(
|
||||||
|
|
@ -200,11 +204,11 @@ pub fn init_session(
|
||||||
custom_model_path: Option<&str>,
|
custom_model_path: Option<&str>,
|
||||||
model: Model,
|
model: Model,
|
||||||
offline: bool,
|
offline: bool,
|
||||||
) -> Result<Box<dyn Session>, Box<dyn Error>> {
|
) -> AppResult<Box<dyn Session>> {
|
||||||
Ok(if let Some(custom_path) = custom_model_path {
|
Ok(if let Some(custom_path) = custom_model_path {
|
||||||
let path = std::path::PathBuf::from(custom_path);
|
let path = std::path::PathBuf::from(custom_path);
|
||||||
if !path.exists() {
|
if !path.exists() {
|
||||||
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
return Err(AppError::CustomModelPathMissing { path });
|
||||||
}
|
}
|
||||||
info!(path = %custom_path, "Using custom model");
|
info!(path = %custom_path, "Using custom model");
|
||||||
match model {
|
match model {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue