refactor
This commit is contained in:
parent
efc3f18e63
commit
395990e47d
8 changed files with 363 additions and 230 deletions
3
src/lib.rs
Normal file
3
src/lib.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
pub mod model;
|
||||||
|
pub mod postprocessing;
|
||||||
|
pub mod sessions;
|
||||||
91
src/main.rs
91
src/main.rs
|
|
@ -1,17 +1,13 @@
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use image::GenericImageView;
|
use image::GenericImageView;
|
||||||
use ndarray::IntoDimension;
|
|
||||||
use ort::value::Tensor;
|
|
||||||
use show_image::{AsImageView, create_window, event};
|
use show_image::{AsImageView, create_window, event};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
|
||||||
mod model;
|
use remove_background::{
|
||||||
mod postprocessing;
|
model::Model,
|
||||||
mod preprocessing;
|
postprocessing::{apply_mask, create_side_by_side},
|
||||||
|
sessions::{BiRefNetLiteSession, BriaSession, Session},
|
||||||
use model::{Model, create_session, get_model_path};
|
};
|
||||||
use postprocessing::{apply_mask, create_side_by_side};
|
|
||||||
use preprocessing::preprocess_image;
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "remove_background")]
|
#[command(name = "remove_background")]
|
||||||
|
|
@ -39,13 +35,11 @@ struct Args {
|
||||||
|
|
||||||
#[show_image::main]
|
#[show_image::main]
|
||||||
fn main() -> Result<(), Box<dyn Error>> {
|
fn main() -> Result<(), Box<dyn Error>> {
|
||||||
// Parse command line arguments
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
println!("=== Background Removal Tool ===\n");
|
println!("=== Background Removal Tool ===\n");
|
||||||
println!("Downloading image from: {}", args.url);
|
println!("Downloading image from: {}", args.url);
|
||||||
|
|
||||||
// Download the image with a user agent (using blocking client)
|
|
||||||
let client = reqwest::blocking::Client::builder()
|
let client = reqwest::blocking::Client::builder()
|
||||||
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
@ -58,69 +52,47 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
let bytes = response.bytes()?;
|
let bytes = response.bytes()?;
|
||||||
println!("Downloaded {} bytes", bytes.len());
|
println!("Downloaded {} bytes", bytes.len());
|
||||||
|
|
||||||
// Load the image
|
|
||||||
let img = image::load_from_memory(&bytes)?;
|
let img = image::load_from_memory(&bytes)?;
|
||||||
let (img_width, img_height) = img.dimensions();
|
let (img_width, img_height) = img.dimensions();
|
||||||
println!("Loaded image: {}x{}\n", img_width, img_height);
|
println!("Loaded image: {}x{}\n", img_width, img_height);
|
||||||
|
|
||||||
// Get model path - either from custom path or by selecting built-in model
|
let mut session: Box<dyn Session> = if let Some(custom_path) = args.model_path.as_deref() {
|
||||||
let model_path = if let Some(custom_path) = args.model_path {
|
let path = std::path::PathBuf::from(custom_path);
|
||||||
// Custom model path provided, use it directly
|
if !path.exists() {
|
||||||
let model_path = std::path::PathBuf::from(&custom_path);
|
|
||||||
if !model_path.exists() {
|
|
||||||
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
||||||
}
|
}
|
||||||
println!("Using custom model: {}", custom_path);
|
println!("Using custom model: {}", custom_path);
|
||||||
model_path
|
match args.model {
|
||||||
|
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
|
||||||
|
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Use built-in model
|
match args.model {
|
||||||
let model_info = args.model.info();
|
Model::Bria => {
|
||||||
println!("Using model: {}", model_info.name);
|
println!("Using model: bria-rmbg");
|
||||||
get_model_path(&model_info, None, args.offline)?
|
Box::new(BriaSession::new(args.offline)?)
|
||||||
|
}
|
||||||
|
Model::BiRefNetLite => {
|
||||||
|
println!("Using model: birefnet-general-lite");
|
||||||
|
Box::new(BiRefNetLiteSession::new(args.offline)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create ONNX Runtime session
|
|
||||||
let mut session = create_session(&model_path)?;
|
|
||||||
let input_name = session.inputs()[0].name().to_string();
|
|
||||||
|
|
||||||
// Preprocess
|
|
||||||
let input_tensor = preprocess_image(&img, 1024, 1024)?;
|
|
||||||
|
|
||||||
// Run inference
|
|
||||||
println!("Running inference...");
|
println!("Running inference...");
|
||||||
let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?;
|
let mask = session.predict(&img)?;
|
||||||
|
|
||||||
// Extract mask output
|
|
||||||
let mask_output = &outputs[0];
|
|
||||||
let (mask_shape, mask_array) = mask_output.try_extract_tensor::<f32>()?; // Returns (shape, &[f32])
|
|
||||||
|
|
||||||
// Convert the slice to Array4
|
|
||||||
let mask_tensor = ndarray::ArrayView::from_shape(
|
|
||||||
(
|
|
||||||
mask_shape[0] as usize,
|
|
||||||
mask_shape[1] as usize,
|
|
||||||
mask_shape[2] as usize,
|
|
||||||
mask_shape[3] as usize,
|
|
||||||
)
|
|
||||||
.into_dimension(),
|
|
||||||
mask_array,
|
|
||||||
)?
|
|
||||||
.to_owned();
|
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"Inference complete! Output shape: {:?}",
|
"Inference complete! Mask dimensions: {}x{}",
|
||||||
mask_tensor.shape()
|
mask.dimensions().0,
|
||||||
|
mask.dimensions().1
|
||||||
);
|
);
|
||||||
|
|
||||||
// Apply mask to remove background
|
let result_rgba = apply_mask(&img, &mask)?;
|
||||||
let result_rgba = apply_mask(&img, mask_tensor)?;
|
|
||||||
|
|
||||||
// Create side-by-side comparison
|
|
||||||
println!("Creating side-by-side comparison...");
|
println!("Creating side-by-side comparison...");
|
||||||
let composite = create_side_by_side(&img, &result_rgba)?;
|
let composite = create_side_by_side(&img, &result_rgba)?;
|
||||||
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
|
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
|
||||||
|
|
||||||
// Display the result
|
|
||||||
let (comp_width, comp_height) = composite_dynamic.dimensions();
|
let (comp_width, comp_height) = composite_dynamic.dimensions();
|
||||||
let window = create_window(
|
let window = create_window(
|
||||||
"Background Removal - Original (Left) vs Result (Right) - Press ESC to close",
|
"Background Removal - Original (Left) vs Result (Right) - Press ESC to close",
|
||||||
|
|
@ -128,7 +100,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
)?;
|
)?;
|
||||||
window.set_image(
|
window.set_image(
|
||||||
"comparison",
|
"comparison",
|
||||||
&composite_dynamic
|
composite_dynamic
|
||||||
.as_image_view()
|
.as_image_view()
|
||||||
.map_err(|e| e.to_string())?,
|
.map_err(|e| e.to_string())?,
|
||||||
)?;
|
)?;
|
||||||
|
|
@ -141,17 +113,16 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
println!(" Left: Original image");
|
println!(" Left: Original image");
|
||||||
println!(" Right: Background removed (shown on checkered background)");
|
println!(" Right: Background removed (shown on checkered background)");
|
||||||
println!("\nPress ESC to close the window.");
|
println!("\nPress ESC to close the window.");
|
||||||
// Event loop - wait for ESC key to close
|
|
||||||
for event in window.event_channel()? {
|
for event in window.event_channel()? {
|
||||||
if let event::WindowEvent::KeyboardInput(event) = event {
|
if let event::WindowEvent::KeyboardInput(event) = event
|
||||||
if event.input.key_code == Some(event::VirtualKeyCode::Escape)
|
&& event.input.key_code == Some(event::VirtualKeyCode::Escape)
|
||||||
&& event.input.state.is_pressed()
|
&& event.input.state.is_pressed()
|
||||||
{
|
{
|
||||||
println!("ESC pressed, closing...");
|
println!("ESC pressed, closing...");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
63
src/model.rs
63
src/model.rs
|
|
@ -7,6 +7,8 @@ use std::fs;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
|
||||||
|
/// preprocessing params) lives on the `Session` trait impls in `sessions/`.
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
#[clap(rename_all = "kebab-case")]
|
#[clap(rename_all = "kebab-case")]
|
||||||
pub enum Model {
|
pub enum Model {
|
||||||
|
|
@ -14,37 +16,6 @@ pub enum Model {
|
||||||
Bria,
|
Bria,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
|
||||||
pub fn info(&self) -> &ModelInfo {
|
|
||||||
match self {
|
|
||||||
Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE,
|
|
||||||
Model::Bria => &ModelInfo::BRIA,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Model metadata
|
|
||||||
pub struct ModelInfo {
|
|
||||||
pub name: &'static str,
|
|
||||||
pub url: &'static str,
|
|
||||||
pub sha256: Option<&'static str>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ModelInfo {
|
|
||||||
/// BiRefNet General Lite model
|
|
||||||
pub const BIREFNET_LITE: ModelInfo = ModelInfo {
|
|
||||||
name: "birefnet-general-lite",
|
|
||||||
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
|
|
||||||
sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"),
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const BRIA: ModelInfo = ModelInfo {
|
|
||||||
name: "bria",
|
|
||||||
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx",
|
|
||||||
sha256: Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the cache directory for models
|
/// Get the cache directory for models
|
||||||
fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
|
fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
|
||||||
let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?;
|
let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?;
|
||||||
|
|
@ -56,7 +27,6 @@ fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
|
||||||
Ok(cache_dir)
|
Ok(cache_dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Download a file from URL to destination
|
|
||||||
fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
||||||
println!("Downloading model from {}...", url);
|
println!("Downloading model from {}...", url);
|
||||||
|
|
||||||
|
|
@ -99,7 +69,6 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verify file SHA256 hash
|
|
||||||
fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Error>> {
|
fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Error>> {
|
||||||
let mut file = fs::File::open(file_path)?;
|
let mut file = fs::File::open(file_path)?;
|
||||||
let mut hasher = Sha256::new();
|
let mut hasher = Sha256::new();
|
||||||
|
|
@ -109,13 +78,20 @@ fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Er
|
||||||
Ok(hash_str == expected_hash)
|
Ok(hash_str == expected_hash)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or download model, return path to model file
|
/// Resolve a model file to a local path, downloading + verifying if needed.
|
||||||
|
///
|
||||||
|
/// - `name`: cache filename stem (e.g. `"bria-rmbg"`)
|
||||||
|
/// - `url`: remote download URL
|
||||||
|
/// - `sha256`: optional expected SHA-256 of the ONNX file
|
||||||
|
/// - `custom_path`: if `Some`, bypass the cache and use this path directly
|
||||||
|
/// - `offline`: if true, fail instead of downloading when the cache is cold
|
||||||
pub fn get_model_path(
|
pub fn get_model_path(
|
||||||
model_info: &ModelInfo,
|
name: &str,
|
||||||
|
url: &str,
|
||||||
|
sha256: Option<&str>,
|
||||||
custom_path: Option<&str>,
|
custom_path: Option<&str>,
|
||||||
offline: bool,
|
offline: bool,
|
||||||
) -> Result<PathBuf, Box<dyn Error>> {
|
) -> Result<PathBuf, Box<dyn Error>> {
|
||||||
// If custom path provided, use it
|
|
||||||
if let Some(path) = custom_path {
|
if let Some(path) = custom_path {
|
||||||
let model_path = PathBuf::from(path);
|
let model_path = PathBuf::from(path);
|
||||||
if !model_path.exists() {
|
if !model_path.exists() {
|
||||||
|
|
@ -125,16 +101,14 @@ pub fn get_model_path(
|
||||||
return Ok(model_path);
|
return Ok(model_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check cache
|
|
||||||
let cache_dir = get_cache_dir()?;
|
let cache_dir = get_cache_dir()?;
|
||||||
let model_filename = format!("{}.onnx", model_info.name);
|
let model_filename = format!("{}.onnx", name);
|
||||||
let model_path = cache_dir.join(&model_filename);
|
let model_path = cache_dir.join(&model_filename);
|
||||||
|
|
||||||
if model_path.exists() {
|
if model_path.exists() {
|
||||||
println!("Using cached model: {}", model_path.display());
|
println!("Using cached model: {}", model_path.display());
|
||||||
|
|
||||||
// Verify hash if provided
|
if let Some(expected_hash) = sha256 {
|
||||||
if let Some(expected_hash) = model_info.sha256 {
|
|
||||||
print!("Verifying model integrity... ");
|
print!("Verifying model integrity... ");
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
if verify_hash(&model_path, expected_hash)? {
|
if verify_hash(&model_path, expected_hash)? {
|
||||||
|
|
@ -150,7 +124,6 @@ pub fn get_model_path(
|
||||||
return Ok(model_path);
|
return Ok(model_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Download if not in offline mode
|
|
||||||
if offline {
|
if offline {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"Model not found in cache and offline mode is enabled. Cache path: {}",
|
"Model not found in cache and offline mode is enabled. Cache path: {}",
|
||||||
|
|
@ -160,16 +133,14 @@ pub fn get_model_path(
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Model not found in cache, downloading...");
|
println!("Model not found in cache, downloading...");
|
||||||
download_file(model_info.url, &model_path)?;
|
download_file(url, &model_path)?;
|
||||||
|
|
||||||
// Verify after download
|
if let Some(expected_hash) = sha256 {
|
||||||
if let Some(expected_hash) = model_info.sha256 {
|
|
||||||
print!("Verifying downloaded model... ");
|
print!("Verifying downloaded model... ");
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
if verify_hash(&model_path, expected_hash)? {
|
if verify_hash(&model_path, expected_hash)? {
|
||||||
println!("OK");
|
println!("OK");
|
||||||
} else {
|
} else {
|
||||||
// Delete corrupted file
|
|
||||||
fs::remove_file(&model_path)?;
|
fs::remove_file(&model_path)?;
|
||||||
return Err("Downloaded model hash verification failed".into());
|
return Err("Downloaded model hash verification failed".into());
|
||||||
}
|
}
|
||||||
|
|
@ -178,7 +149,7 @@ pub fn get_model_path(
|
||||||
Ok(model_path)
|
Ok(model_path)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an ONNX Runtime session from model path with CUDA backend
|
/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU).
|
||||||
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
||||||
println!("Loading model into ONNX Runtime with CUDA backend...");
|
println!("Loading model into ONNX Runtime with CUDA backend...");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,53 +1,32 @@
|
||||||
use image::{DynamicImage, GenericImageView, Rgba, RgbaImage, imageops::FilterType};
|
use image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType};
|
||||||
use ndarray::Array4;
|
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
|
||||||
/// Apply mask to original image to remove background
|
/// Compose `original` with `mask` as the alpha channel and return an RGBA image.
|
||||||
///
|
///
|
||||||
/// Steps:
|
/// The mask is expected to already be grayscale. If its dimensions differ from
|
||||||
/// 1. Extract mask from output tensor (shape: [1, 1, H, W])
|
/// the original, it is resized with LANCZOS3.
|
||||||
/// 2. Resize mask to match original image dimensions
|
pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage, Box<dyn Error>> {
|
||||||
/// 3. Apply mask as alpha channel to create RGBA image
|
|
||||||
pub fn apply_mask(
|
|
||||||
original: &DynamicImage,
|
|
||||||
mask_tensor: Array4<f32>,
|
|
||||||
) -> Result<RgbaImage, Box<dyn Error>> {
|
|
||||||
println!("Applying mask to remove background...");
|
println!("Applying mask to remove background...");
|
||||||
|
|
||||||
let (orig_width, orig_height) = original.dimensions();
|
let (orig_width, orig_height) = original.dimensions();
|
||||||
|
let (mask_width, mask_height) = mask.dimensions();
|
||||||
// Extract mask dimensions
|
|
||||||
let mask_shape = mask_tensor.shape();
|
|
||||||
if mask_shape[0] != 1 || mask_shape[1] != 1 {
|
|
||||||
return Err(format!("Expected mask shape [1, 1, H, W], got {:?}", mask_shape).into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mask_height = mask_shape[2] as u32;
|
|
||||||
let mask_width = mask_shape[3] as u32;
|
|
||||||
|
|
||||||
println!("Mask dimensions: {}x{}", mask_width, mask_height);
|
println!("Mask dimensions: {}x{}", mask_width, mask_height);
|
||||||
println!("Original dimensions: {}x{}", orig_width, orig_height);
|
println!("Original dimensions: {}x{}", orig_width, orig_height);
|
||||||
|
|
||||||
// Create a grayscale image from the mask
|
let resized_mask: std::borrow::Cow<'_, GrayImage> =
|
||||||
let mut mask_image = image::GrayImage::new(mask_width, mask_height);
|
if mask_width != orig_width || mask_height != orig_height {
|
||||||
for y in 0..mask_height {
|
|
||||||
for x in 0..mask_width {
|
|
||||||
let mask_value = mask_tensor[[0, 0, y as usize, x as usize]];
|
|
||||||
// Clamp and convert to u8
|
|
||||||
let pixel_value = (mask_value.clamp(0.0, 1.0) * 255.0) as u8;
|
|
||||||
mask_image.put_pixel(x, y, image::Luma([pixel_value]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resize mask to match original image dimensions if needed
|
|
||||||
let resized_mask = if mask_width != orig_width || mask_height != orig_height {
|
|
||||||
println!("Resizing mask to match original image...");
|
println!("Resizing mask to match original image...");
|
||||||
image::imageops::resize(&mask_image, orig_width, orig_height, FilterType::Lanczos3)
|
std::borrow::Cow::Owned(image::imageops::resize(
|
||||||
|
mask,
|
||||||
|
orig_width,
|
||||||
|
orig_height,
|
||||||
|
FilterType::Lanczos3,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
mask_image
|
std::borrow::Cow::Borrowed(mask)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Convert original to RGBA and apply mask
|
|
||||||
let rgba_original = original.to_rgba8();
|
let rgba_original = original.to_rgba8();
|
||||||
let mut result = RgbaImage::new(orig_width, orig_height);
|
let mut result = RgbaImage::new(orig_width, orig_height);
|
||||||
|
|
||||||
|
|
@ -56,16 +35,10 @@ pub fn apply_mask(
|
||||||
let orig_pixel = rgba_original.get_pixel(x, y);
|
let orig_pixel = rgba_original.get_pixel(x, y);
|
||||||
let mask_pixel = resized_mask.get_pixel(x, y);
|
let mask_pixel = resized_mask.get_pixel(x, y);
|
||||||
|
|
||||||
// Apply mask as alpha channel
|
|
||||||
result.put_pixel(
|
result.put_pixel(
|
||||||
x,
|
x,
|
||||||
y,
|
y,
|
||||||
Rgba([
|
Rgba([orig_pixel[0], orig_pixel[1], orig_pixel[2], mask_pixel[0]]),
|
||||||
orig_pixel[0], // R
|
|
||||||
orig_pixel[1], // G
|
|
||||||
orig_pixel[2], // B
|
|
||||||
mask_pixel[0], // Alpha from mask
|
|
||||||
]),
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -83,7 +56,6 @@ pub fn create_side_by_side(
|
||||||
let (width, height) = original.dimensions();
|
let (width, height) = original.dimensions();
|
||||||
let mut composite = RgbaImage::new(width * 2, height);
|
let mut composite = RgbaImage::new(width * 2, height);
|
||||||
|
|
||||||
// Left side: original image
|
|
||||||
let original_rgba = original.to_rgba8();
|
let original_rgba = original.to_rgba8();
|
||||||
for y in 0..height {
|
for y in 0..height {
|
||||||
for x in 0..width {
|
for x in 0..width {
|
||||||
|
|
@ -97,12 +69,10 @@ pub fn create_side_by_side(
|
||||||
let result_pixel = result.get_pixel(x, y);
|
let result_pixel = result.get_pixel(x, y);
|
||||||
let alpha = result_pixel[3] as f32 / 255.0;
|
let alpha = result_pixel[3] as f32 / 255.0;
|
||||||
|
|
||||||
// Create checkered background (8x8 squares)
|
|
||||||
let checker_size = 8;
|
let checker_size = 8;
|
||||||
let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0;
|
let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0;
|
||||||
let bg_color = if is_light { 200 } else { 150 };
|
let bg_color = if is_light { 200 } else { 150 };
|
||||||
|
|
||||||
// Alpha blend with checkered background
|
|
||||||
let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
|
let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
|
||||||
let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
|
let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
|
||||||
let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
|
let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
|
||||||
|
|
@ -117,13 +87,16 @@ pub fn create_side_by_side(
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use ndarray::Array4;
|
use image::{GrayImage, Luma};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_apply_mask_shape() {
|
fn test_apply_mask_shape() {
|
||||||
let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
|
let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
|
||||||
let mask = Array4::<f32>::ones((1, 1, 100, 100));
|
let mut mask = GrayImage::new(100, 100);
|
||||||
let result = apply_mask(&img, mask).unwrap();
|
for p in mask.pixels_mut() {
|
||||||
|
*p = Luma([255]);
|
||||||
|
}
|
||||||
|
let result = apply_mask(&img, &mask).unwrap();
|
||||||
assert_eq!(result.dimensions(), (100, 100));
|
assert_eq!(result.dimensions(), (100, 100));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -132,6 +105,6 @@ mod tests {
|
||||||
let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
|
let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
|
||||||
let result_img = RgbaImage::new(100, 100);
|
let result_img = RgbaImage::new(100, 100);
|
||||||
let composite = create_side_by_side(&img, &result_img).unwrap();
|
let composite = create_side_by_side(&img, &result_img).unwrap();
|
||||||
assert_eq!(composite.dimensions(), (200, 100)); // Double width
|
assert_eq!(composite.dimensions(), (200, 100));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,67 +0,0 @@
|
||||||
use image::{DynamicImage, imageops::FilterType};
|
|
||||||
use ndarray::Array4;
|
|
||||||
use std::error::Error;
|
|
||||||
|
|
||||||
/// Preprocess an image for the BiRefNet model
|
|
||||||
///
|
|
||||||
/// Steps:
|
|
||||||
/// 1. Resize to target dimensions (1024x1024)
|
|
||||||
/// 2. Convert from u8 [0, 255] to f32 [0.0, 1.0]
|
|
||||||
/// 3. Rearrange from HWC (Height, Width, Channels) to CHW format
|
|
||||||
/// 4. Add batch dimension: [1, 3, H, W]
|
|
||||||
pub fn preprocess_image(
|
|
||||||
img: &DynamicImage,
|
|
||||||
target_width: u32,
|
|
||||||
target_height: u32,
|
|
||||||
) -> Result<Array4<f32>, Box<dyn Error>> {
|
|
||||||
println!("Preprocessing image...");
|
|
||||||
|
|
||||||
// Step 1: Resize image
|
|
||||||
let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3);
|
|
||||||
let rgb_image = resized.to_rgb8();
|
|
||||||
|
|
||||||
let (width, height) = rgb_image.dimensions();
|
|
||||||
|
|
||||||
// Step 2: Create ndarray with shape [1, 3, height, width]
|
|
||||||
let mut array = Array4::<f32>::zeros((1, 3, height as usize, width as usize));
|
|
||||||
|
|
||||||
// Step 3: Fill the array, converting from HWC to CHW and normalizing
|
|
||||||
for y in 0..height {
|
|
||||||
for x in 0..width {
|
|
||||||
let pixel = rgb_image.get_pixel(x, y);
|
|
||||||
|
|
||||||
// Normalize from [0, 255] to [0.0, 1.0]
|
|
||||||
array[[0, 0, y as usize, x as usize]] = pixel[0] as f32 / 255.0; // R
|
|
||||||
array[[0, 1, y as usize, x as usize]] = pixel[1] as f32 / 255.0; // G
|
|
||||||
array[[0, 2, y as usize, x as usize]] = pixel[2] as f32 / 255.0; // B
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("Preprocessing complete. Tensor shape: {:?}", array.shape());
|
|
||||||
|
|
||||||
Ok(array)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use image::RgbImage;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_preprocess_shape() {
|
|
||||||
let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
|
|
||||||
let result = preprocess_image(&img, 1024, 1024).unwrap();
|
|
||||||
assert_eq!(result.shape(), &[1, 3, 1024, 1024]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_preprocess_normalization() {
|
|
||||||
let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
|
|
||||||
let result = preprocess_image(&img, 1024, 1024).unwrap();
|
|
||||||
|
|
||||||
// Check that all values are in [0.0, 1.0]
|
|
||||||
for &val in result.iter() {
|
|
||||||
assert!(val >= 0.0 && val <= 1.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
54
src/sessions/birefnet_lite.rs
Normal file
54
src/sessions/birefnet_lite.rs
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
use ort::session::Session as OrtSession;
|
||||||
|
use std::error::Error;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use crate::model::{create_session, get_model_path};
|
||||||
|
|
||||||
|
use super::Session;
|
||||||
|
|
||||||
|
pub struct BiRefNetLiteSession {
|
||||||
|
inner_session: OrtSession,
|
||||||
|
input_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BiRefNetLiteSession {
|
||||||
|
pub fn new(offline: bool) -> Result<Self, Box<dyn Error>> {
|
||||||
|
let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?;
|
||||||
|
Self::from_model_path(&path)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_model_path(path: &Path) -> Result<Self, Box<dyn Error>> {
|
||||||
|
let inner_session = create_session(path)?;
|
||||||
|
let input_name = inner_session.inputs()[0].name().to_string();
|
||||||
|
Ok(Self {
|
||||||
|
inner_session,
|
||||||
|
input_name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Session for BiRefNetLiteSession {
|
||||||
|
fn name() -> &'static str {
|
||||||
|
"birefnet-general-lite"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn url() -> &'static str {
|
||||||
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sha256() -> Option<&'static str> {
|
||||||
|
Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_sigmoid(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn inner(&mut self) -> &mut OrtSession {
|
||||||
|
&mut self.inner_session
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_name(&self) -> &str {
|
||||||
|
&self.input_name
|
||||||
|
}
|
||||||
|
}
|
||||||
50
src/sessions/bria.rs
Normal file
50
src/sessions/bria.rs
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
use ort::session::Session as OrtSession;
|
||||||
|
use std::error::Error;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use crate::model::{create_session, get_model_path};
|
||||||
|
|
||||||
|
use super::Session;
|
||||||
|
|
||||||
|
pub struct BriaSession {
|
||||||
|
inner_session: OrtSession,
|
||||||
|
input_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BriaSession {
|
||||||
|
pub fn new(offline: bool) -> Result<Self, Box<dyn Error>> {
|
||||||
|
let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?;
|
||||||
|
Self::from_model_path(&path)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_model_path(path: &Path) -> Result<Self, Box<dyn Error>> {
|
||||||
|
let inner_session = create_session(path)?;
|
||||||
|
let input_name = inner_session.inputs()[0].name().to_string();
|
||||||
|
Ok(Self {
|
||||||
|
inner_session,
|
||||||
|
input_name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Session for BriaSession {
|
||||||
|
fn name() -> &'static str {
|
||||||
|
"bria-rmbg"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn url() -> &'static str {
|
||||||
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sha256() -> Option<&'static str> {
|
||||||
|
Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn inner(&mut self) -> &mut OrtSession {
|
||||||
|
&mut self.inner_session
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_name(&self) -> &str {
|
||||||
|
&self.input_name
|
||||||
|
}
|
||||||
|
}
|
||||||
178
src/sessions/mod.rs
Normal file
178
src/sessions/mod.rs
Normal file
|
|
@ -0,0 +1,178 @@
|
||||||
|
use image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType};
|
||||||
|
use ndarray::{Array4, IntoDimension};
|
||||||
|
use ort::{session::Session as OrtSession, value::Tensor};
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
|
mod birefnet_lite;
|
||||||
|
mod bria;
|
||||||
|
|
||||||
|
pub use birefnet_lite::BiRefNetLiteSession;
|
||||||
|
pub use bria::BriaSession;
|
||||||
|
|
||||||
|
/// Common interface for background-removal models, mirroring rembg's
|
||||||
|
/// `BaseSession` pattern. Each concrete session owns an `ort::Session` and
|
||||||
|
/// implements `inner` / `input_name`; the rest (preprocessing, inference,
|
||||||
|
/// postprocessing into a mask) is provided by default implementations.
|
||||||
|
pub trait Session {
|
||||||
|
/// Canonical model name (matches rembg's filename stem).
|
||||||
|
fn name() -> &'static str
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
|
|
||||||
|
/// URL to download the ONNX model from.
|
||||||
|
fn url() -> &'static str
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
|
|
||||||
|
/// Optional SHA-256 checksum used to verify the cached model.
|
||||||
|
fn sha256() -> Option<&'static str>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mean(&self) -> (f32, f32, f32) {
|
||||||
|
(0.485, 0.456, 0.406)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn std(&self) -> (f32, f32, f32) {
|
||||||
|
(0.229, 0.224, 0.225)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_size(&self) -> (u32, u32) {
|
||||||
|
(1024, 1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether a sigmoid should be applied to the raw logits before the
|
||||||
|
/// min/max normalization step. `birefnet-*` needs this; `bria-rmbg` does not.
|
||||||
|
fn apply_sigmoid(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn inner(&mut self) -> &mut OrtSession;
|
||||||
|
|
||||||
|
fn input_name(&self) -> &str;
|
||||||
|
|
||||||
|
/// Port of rembg's `BaseSession.normalize`: resize with LANCZOS,
|
||||||
|
/// scale into `[0, 1]` by dividing by the max pixel value, then apply
|
||||||
|
/// channel-wise mean/std.
|
||||||
|
fn normalize(&self, img: &DynamicImage) -> Result<Array4<f32>, Box<dyn Error>> {
|
||||||
|
let (w, h) = self.input_size();
|
||||||
|
let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8();
|
||||||
|
let (width, height) = resized.dimensions();
|
||||||
|
|
||||||
|
let mut max_pixel: f32 = 0.0;
|
||||||
|
for p in resized.pixels() {
|
||||||
|
for c in 0..3 {
|
||||||
|
let v = p[c] as f32;
|
||||||
|
if v > max_pixel {
|
||||||
|
max_pixel = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let denom = max_pixel.max(1e-6);
|
||||||
|
|
||||||
|
let mean = self.mean();
|
||||||
|
let std = self.std();
|
||||||
|
|
||||||
|
let mut array = Array4::<f32>::zeros((1, 3, height as usize, width as usize));
|
||||||
|
for y in 0..height {
|
||||||
|
for x in 0..width {
|
||||||
|
let pixel = resized.get_pixel(x, y);
|
||||||
|
let r = pixel[0] as f32 / denom;
|
||||||
|
let g = pixel[1] as f32 / denom;
|
||||||
|
let b = pixel[2] as f32 / denom;
|
||||||
|
array[[0, 0, y as usize, x as usize]] = (r - mean.0) / std.0;
|
||||||
|
array[[0, 1, y as usize, x as usize]] = (g - mean.1) / std.1;
|
||||||
|
array[[0, 2, y as usize, x as usize]] = (b - mean.2) / std.2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(array)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run inference and return a grayscale mask resized to the input image.
|
||||||
|
/// Mirrors rembg's `predict`:
|
||||||
|
/// 1. `inner_session.run(normalize(img, mean, std, size))`
|
||||||
|
/// 2. take the first output, channel 0
|
||||||
|
/// 3. optional sigmoid (birefnet)
|
||||||
|
/// 4. min/max normalize into `[0, 1]`
|
||||||
|
/// 5. scale to `u8`, resize to original image dimensions
|
||||||
|
fn predict(&mut self, img: &DynamicImage) -> Result<GrayImage, Box<dyn Error>> {
|
||||||
|
let (orig_w, orig_h) = img.dimensions();
|
||||||
|
let input = self.normalize(img)?;
|
||||||
|
let apply_sigmoid = self.apply_sigmoid();
|
||||||
|
|
||||||
|
let input_name = self.input_name().to_string();
|
||||||
|
let outputs = self
|
||||||
|
.inner()
|
||||||
|
.run(ort::inputs![input_name => Tensor::from_array(input)?])?;
|
||||||
|
|
||||||
|
let output = &outputs[0];
|
||||||
|
let (shape, data) = output.try_extract_tensor::<f32>()?;
|
||||||
|
|
||||||
|
if shape.len() != 4 {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected 4D output tensor [N, C, H, W], got shape {:?}",
|
||||||
|
shape
|
||||||
|
)
|
||||||
|
.into());
|
||||||
|
}
|
||||||
|
let (n, _c, h, w) = (
|
||||||
|
shape[0] as usize,
|
||||||
|
shape[1] as usize,
|
||||||
|
shape[2] as usize,
|
||||||
|
shape[3] as usize,
|
||||||
|
);
|
||||||
|
if n != 1 {
|
||||||
|
return Err(format!("Expected batch size 1, got {}", n).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let view = ndarray::ArrayView::from_shape(
|
||||||
|
(
|
||||||
|
shape[0] as usize,
|
||||||
|
shape[1] as usize,
|
||||||
|
shape[2] as usize,
|
||||||
|
shape[3] as usize,
|
||||||
|
)
|
||||||
|
.into_dimension(),
|
||||||
|
data,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Take channel 0: pred = out[:, 0, :, :]
|
||||||
|
let mut pred: Vec<f32> = Vec::with_capacity(h * w);
|
||||||
|
for y in 0..h {
|
||||||
|
for x in 0..w {
|
||||||
|
let mut v = view[[0, 0, y, x]];
|
||||||
|
if apply_sigmoid {
|
||||||
|
v = 1.0 / (1.0 + (-v).exp());
|
||||||
|
}
|
||||||
|
pred.push(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (mut mi, mut ma) = (f32::INFINITY, f32::NEG_INFINITY);
|
||||||
|
for &v in &pred {
|
||||||
|
if v < mi {
|
||||||
|
mi = v;
|
||||||
|
}
|
||||||
|
if v > ma {
|
||||||
|
ma = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let range = (ma - mi).max(1e-6);
|
||||||
|
|
||||||
|
let mut mask = GrayImage::new(w as u32, h as u32);
|
||||||
|
for y in 0..h {
|
||||||
|
for x in 0..w {
|
||||||
|
let v = (pred[y * w + x] - mi) / range;
|
||||||
|
let u = (v.clamp(0.0, 1.0) * 255.0).round() as u8;
|
||||||
|
mask.put_pixel(x as u32, y as u32, Luma([u]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3);
|
||||||
|
Ok(mask)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue