Compare commits

..

No commits in common. "395990e47ddc7a7df547473ee30d5f6cc0752e3c" and "b979fadb434937de36a2ad184ff45118b9fc1bad" have entirely different histories.

10 changed files with 240 additions and 373 deletions

View file

@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1776548001,
"narHash": "sha256-ZSK0NL4a1BwVbbTBoSnWgbJy9HeZFXLYQizjb2DPF24=",
"lastModified": 1769789167,
"narHash": "sha256-kKB3bqYJU5nzYeIROI82Ef9VtTbu4uA3YydSk/Bioa8=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "b12141ef619e0a9c1c84dc8c684040326f27cdcc",
"rev": "62c8382960464ceb98ea593cb8321a2cf8f9e3e5",
"type": "github"
},
"original": {
@ -62,11 +62,11 @@
"nixpkgs": "nixpkgs_2"
},
"locked": {
"lastModified": 1776827647,
"narHash": "sha256-sYixYhp5V8jCajO8TRorE4fzs7IkL4MZdfLTKgkPQBk=",
"lastModified": 1769915446,
"narHash": "sha256-f1F/umtX3ZD7fF9DHSloVHc0mnAT0ry0YK2jI/6E0aI=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "40e6ccc06e1245a4837cbbd6bdda64e21cc67379",
"rev": "bc00300f010275e46feb3c3974df6587ff7b7808",
"type": "github"
},
"original": {

View file

@ -32,10 +32,10 @@
];
xorgBuildInputs = with pkgs; [
xorg.libx11
xorg.libxcursor
xorg.libxi
xorg.libxrandr
xorg.libX11
xorg.libXcursor
xorg.libXi
xorg.libXrandr
];
waylandBuildInputs = with pkgs; [

View file

@ -1,3 +0,0 @@
pub mod model;
pub mod postprocessing;
pub mod sessions;

View file

@ -1,13 +1,17 @@
use clap::Parser;
use image::GenericImageView;
use ndarray::IntoDimension;
use ort::value::Tensor;
use show_image::{AsImageView, create_window, event};
use std::error::Error;
use remove_background::{
model::Model,
postprocessing::{apply_mask, create_side_by_side},
sessions::{BiRefNetLiteSession, BriaSession, Session},
};
mod model;
mod postprocessing;
mod preprocessing;
use model::{Model, create_session, get_model_path};
use postprocessing::{apply_mask, create_side_by_side};
use preprocessing::preprocess_image;
#[derive(Parser)]
#[command(name = "remove_background")]
@ -35,11 +39,13 @@ struct Args {
#[show_image::main]
fn main() -> Result<(), Box<dyn Error>> {
// Parse command line arguments
let args = Args::parse();
println!("=== Background Removal Tool ===\n");
println!("Downloading image from: {}", args.url);
// Download the image with a user agent (using blocking client)
let client = reqwest::blocking::Client::builder()
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
.build()?;
@ -52,47 +58,69 @@ fn main() -> Result<(), Box<dyn Error>> {
let bytes = response.bytes()?;
println!("Downloaded {} bytes", bytes.len());
// Load the image
let img = image::load_from_memory(&bytes)?;
let (img_width, img_height) = img.dimensions();
println!("Loaded image: {}x{}\n", img_width, img_height);
let mut session: Box<dyn Session> = if let Some(custom_path) = args.model_path.as_deref() {
let path = std::path::PathBuf::from(custom_path);
if !path.exists() {
// Get model path - either from custom path or by selecting built-in model
let model_path = if let Some(custom_path) = args.model_path {
// Custom model path provided, use it directly
let model_path = std::path::PathBuf::from(&custom_path);
if !model_path.exists() {
return Err(format!("Custom model path does not exist: {}", custom_path).into());
}
println!("Using custom model: {}", custom_path);
match args.model {
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
}
model_path
} else {
match args.model {
Model::Bria => {
println!("Using model: bria-rmbg");
Box::new(BriaSession::new(args.offline)?)
}
Model::BiRefNetLite => {
println!("Using model: birefnet-general-lite");
Box::new(BiRefNetLiteSession::new(args.offline)?)
}
}
// Use built-in model
let model_info = args.model.info();
println!("Using model: {}", model_info.name);
get_model_path(&model_info, None, args.offline)?
};
// Create ONNX Runtime session
let mut session = create_session(&model_path)?;
let input_name = session.inputs()[0].name().to_string();
// Preprocess
let input_tensor = preprocess_image(&img, 1024, 1024)?;
// Run inference
println!("Running inference...");
let mask = session.predict(&img)?;
let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?;
// Extract mask output
let mask_output = &outputs[0];
let (mask_shape, mask_array) = mask_output.try_extract_tensor::<f32>()?; // Returns (shape, &[f32])
// Convert the slice to Array4
let mask_tensor = ndarray::ArrayView::from_shape(
(
mask_shape[0] as usize,
mask_shape[1] as usize,
mask_shape[2] as usize,
mask_shape[3] as usize,
)
.into_dimension(),
mask_array,
)?
.to_owned();
println!(
"Inference complete! Mask dimensions: {}x{}",
mask.dimensions().0,
mask.dimensions().1
"Inference complete! Output shape: {:?}",
mask_tensor.shape()
);
let result_rgba = apply_mask(&img, &mask)?;
// Apply mask to remove background
let result_rgba = apply_mask(&img, mask_tensor)?;
// Create side-by-side comparison
println!("Creating side-by-side comparison...");
let composite = create_side_by_side(&img, &result_rgba)?;
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
// Display the result
let (comp_width, comp_height) = composite_dynamic.dimensions();
let window = create_window(
"Background Removal - Original (Left) vs Result (Right) - Press ESC to close",
@ -100,7 +128,7 @@ fn main() -> Result<(), Box<dyn Error>> {
)?;
window.set_image(
"comparison",
composite_dynamic
&composite_dynamic
.as_image_view()
.map_err(|e| e.to_string())?,
)?;
@ -113,16 +141,17 @@ fn main() -> Result<(), Box<dyn Error>> {
println!(" Left: Original image");
println!(" Right: Background removed (shown on checkered background)");
println!("\nPress ESC to close the window.");
// Event loop - wait for ESC key to close
for event in window.event_channel()? {
if let event::WindowEvent::KeyboardInput(event) = event
&& event.input.key_code == Some(event::VirtualKeyCode::Escape)
if let event::WindowEvent::KeyboardInput(event) = event {
if event.input.key_code == Some(event::VirtualKeyCode::Escape)
&& event.input.state.is_pressed()
{
println!("ESC pressed, closing...");
break;
}
}
}
Ok(())
}

View file

@ -7,8 +7,6 @@ use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
/// preprocessing params) lives on the `Session` trait impls in `sessions/`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
#[clap(rename_all = "kebab-case")]
pub enum Model {
@ -16,6 +14,37 @@ pub enum Model {
Bria,
}
impl Model {
    /// Return the static metadata (name, download URL, checksum) for this model.
    ///
    /// The returned reference is `'static`: it points at one of the
    /// `ModelInfo` associated constants, so it does not keep `self` borrowed.
    pub fn info(&self) -> &'static ModelInfo {
        match self {
            Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE,
            Model::Bria => &ModelInfo::BRIA,
        }
    }
}
/// Static metadata describing a built-in model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelInfo {
    /// Model name.
    pub name: &'static str,
    /// URL the ONNX file is downloaded from.
    pub url: &'static str,
    /// Expected SHA-256 of the model file, or `None` if no checksum is known.
    pub sha256: Option<&'static str>,
}
impl ModelInfo {
/// BiRefNet General Lite model
pub const BIREFNET_LITE: ModelInfo = ModelInfo {
name: "birefnet-general-lite",
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"),
};
pub const BRIA: ModelInfo = ModelInfo {
name: "bria",
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx",
sha256: Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"),
};
}
/// Get the cache directory for models
fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?;
@ -27,6 +56,7 @@ fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
Ok(cache_dir)
}
/// Download a file from URL to destination
fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
println!("Downloading model from {}...", url);
@ -69,6 +99,7 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
Ok(())
}
/// Verify file SHA256 hash
fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Error>> {
let mut file = fs::File::open(file_path)?;
let mut hasher = Sha256::new();
@ -78,20 +109,13 @@ fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Er
Ok(hash_str == expected_hash)
}
/// Resolve a model file to a local path, downloading + verifying if needed.
///
/// - `name`: cache filename stem (e.g. `"bria-rmbg"`)
/// - `url`: remote download URL
/// - `sha256`: optional expected SHA-256 of the ONNX file
/// - `custom_path`: if `Some`, bypass the cache and use this path directly
/// - `offline`: if true, fail instead of downloading when the cache is cold
/// Get or download model, return path to model file
pub fn get_model_path(
name: &str,
url: &str,
sha256: Option<&str>,
model_info: &ModelInfo,
custom_path: Option<&str>,
offline: bool,
) -> Result<PathBuf, Box<dyn Error>> {
// If custom path provided, use it
if let Some(path) = custom_path {
let model_path = PathBuf::from(path);
if !model_path.exists() {
@ -101,14 +125,16 @@ pub fn get_model_path(
return Ok(model_path);
}
// Check cache
let cache_dir = get_cache_dir()?;
let model_filename = format!("{}.onnx", name);
let model_filename = format!("{}.onnx", model_info.name);
let model_path = cache_dir.join(&model_filename);
if model_path.exists() {
println!("Using cached model: {}", model_path.display());
if let Some(expected_hash) = sha256 {
// Verify hash if provided
if let Some(expected_hash) = model_info.sha256 {
print!("Verifying model integrity... ");
std::io::stdout().flush()?;
if verify_hash(&model_path, expected_hash)? {
@ -124,6 +150,7 @@ pub fn get_model_path(
return Ok(model_path);
}
// Download if not in offline mode
if offline {
return Err(format!(
"Model not found in cache and offline mode is enabled. Cache path: {}",
@ -133,14 +160,16 @@ pub fn get_model_path(
}
println!("Model not found in cache, downloading...");
download_file(url, &model_path)?;
download_file(model_info.url, &model_path)?;
if let Some(expected_hash) = sha256 {
// Verify after download
if let Some(expected_hash) = model_info.sha256 {
print!("Verifying downloaded model... ");
std::io::stdout().flush()?;
if verify_hash(&model_path, expected_hash)? {
println!("OK");
} else {
// Delete corrupted file
fs::remove_file(&model_path)?;
return Err("Downloaded model hash verification failed".into());
}
@ -149,7 +178,7 @@ pub fn get_model_path(
Ok(model_path)
}
/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU).
/// Create an ONNX Runtime session from model path with CUDA backend
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
println!("Loading model into ONNX Runtime with CUDA backend...");

View file

@ -1,32 +1,53 @@
use image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType};
use image::{DynamicImage, GenericImageView, Rgba, RgbaImage, imageops::FilterType};
use ndarray::Array4;
use std::error::Error;
/// Compose `original` with `mask` as the alpha channel and return an RGBA image.
/// Apply mask to original image to remove background
///
/// The mask is expected to already be grayscale. If its dimensions differ from
/// the original, it is resized with LANCZOS3.
pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage, Box<dyn Error>> {
/// Steps:
/// 1. Extract mask from output tensor (shape: [1, 1, H, W])
/// 2. Resize mask to match original image dimensions
/// 3. Apply mask as alpha channel to create RGBA image
pub fn apply_mask(
original: &DynamicImage,
mask_tensor: Array4<f32>,
) -> Result<RgbaImage, Box<dyn Error>> {
println!("Applying mask to remove background...");
let (orig_width, orig_height) = original.dimensions();
let (mask_width, mask_height) = mask.dimensions();
// Extract mask dimensions
let mask_shape = mask_tensor.shape();
if mask_shape[0] != 1 || mask_shape[1] != 1 {
return Err(format!("Expected mask shape [1, 1, H, W], got {:?}", mask_shape).into());
}
let mask_height = mask_shape[2] as u32;
let mask_width = mask_shape[3] as u32;
println!("Mask dimensions: {}x{}", mask_width, mask_height);
println!("Original dimensions: {}x{}", orig_width, orig_height);
let resized_mask: std::borrow::Cow<'_, GrayImage> =
if mask_width != orig_width || mask_height != orig_height {
// Create a grayscale image from the mask
let mut mask_image = image::GrayImage::new(mask_width, mask_height);
for y in 0..mask_height {
for x in 0..mask_width {
let mask_value = mask_tensor[[0, 0, y as usize, x as usize]];
// Clamp and convert to u8
let pixel_value = (mask_value.clamp(0.0, 1.0) * 255.0) as u8;
mask_image.put_pixel(x, y, image::Luma([pixel_value]));
}
}
// Resize mask to match original image dimensions if needed
let resized_mask = if mask_width != orig_width || mask_height != orig_height {
println!("Resizing mask to match original image...");
std::borrow::Cow::Owned(image::imageops::resize(
mask,
orig_width,
orig_height,
FilterType::Lanczos3,
))
image::imageops::resize(&mask_image, orig_width, orig_height, FilterType::Lanczos3)
} else {
std::borrow::Cow::Borrowed(mask)
mask_image
};
// Convert original to RGBA and apply mask
let rgba_original = original.to_rgba8();
let mut result = RgbaImage::new(orig_width, orig_height);
@ -35,10 +56,16 @@ pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage
let orig_pixel = rgba_original.get_pixel(x, y);
let mask_pixel = resized_mask.get_pixel(x, y);
// Apply mask as alpha channel
result.put_pixel(
x,
y,
Rgba([orig_pixel[0], orig_pixel[1], orig_pixel[2], mask_pixel[0]]),
Rgba([
orig_pixel[0], // R
orig_pixel[1], // G
orig_pixel[2], // B
mask_pixel[0], // Alpha from mask
]),
);
}
}
@ -56,6 +83,7 @@ pub fn create_side_by_side(
let (width, height) = original.dimensions();
let mut composite = RgbaImage::new(width * 2, height);
// Left side: original image
let original_rgba = original.to_rgba8();
for y in 0..height {
for x in 0..width {
@ -69,10 +97,12 @@ pub fn create_side_by_side(
let result_pixel = result.get_pixel(x, y);
let alpha = result_pixel[3] as f32 / 255.0;
// Create checkered background (8x8 squares)
let checker_size = 8;
let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0;
let bg_color = if is_light { 200 } else { 150 };
// Alpha blend with checkered background
let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
@ -87,16 +117,13 @@ pub fn create_side_by_side(
#[cfg(test)]
mod tests {
use super::*;
use image::{GrayImage, Luma};
use ndarray::Array4;
#[test]
fn test_apply_mask_shape() {
let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
let mut mask = GrayImage::new(100, 100);
for p in mask.pixels_mut() {
*p = Luma([255]);
}
let result = apply_mask(&img, &mask).unwrap();
let mask = Array4::<f32>::ones((1, 1, 100, 100));
let result = apply_mask(&img, mask).unwrap();
assert_eq!(result.dimensions(), (100, 100));
}
@ -105,6 +132,6 @@ mod tests {
let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
let result_img = RgbaImage::new(100, 100);
let composite = create_side_by_side(&img, &result_img).unwrap();
assert_eq!(composite.dimensions(), (200, 100));
assert_eq!(composite.dimensions(), (200, 100)); // Double width
}
}

67
src/preprocessing.rs Normal file
View file

@ -0,0 +1,67 @@
use image::{DynamicImage, imageops::FilterType};
use ndarray::Array4;
use std::error::Error;
/// Convert an image into a `[1, 3, H, W]` `f32` tensor for model input.
///
/// Steps:
/// 1. Resize to exactly `target_width` x `target_height` with Lanczos3
/// 2. Convert to 8-bit RGB and scale each channel from u8 [0, 255] to f32 [0.0, 1.0]
/// 3. Rearrange from HWC (Height, Width, Channels) to CHW format
/// 4. Add batch dimension: [1, 3, H, W]
///
/// No mean/std normalization is applied here.
///
/// # Errors
///
/// Returns an error if `target_width` or `target_height` is zero.
pub fn preprocess_image(
    img: &DynamicImage,
    target_width: u32,
    target_height: u32,
) -> Result<Array4<f32>, Box<dyn Error>> {
    // A zero-sized target cannot produce a usable tensor; fail early with a clear message.
    if target_width == 0 || target_height == 0 {
        return Err(format!(
            "Target dimensions must be non-zero, got {}x{}",
            target_width, target_height
        )
        .into());
    }

    println!("Preprocessing image...");

    // Step 1: resize and drop any alpha channel.
    let rgb_image = img
        .resize_exact(target_width, target_height, FilterType::Lanczos3)
        .to_rgb8();
    let (width, height) = rgb_image.dimensions();

    // Step 2: allocate the [1, 3, height, width] tensor.
    let mut array = Array4::<f32>::zeros((1, 3, height as usize, width as usize));

    // Step 3: write each pixel into its three channel planes, scaled to [0.0, 1.0].
    for (x, y, pixel) in rgb_image.enumerate_pixels() {
        let (row, col) = (y as usize, x as usize);
        let [r, g, b] = pixel.0;
        array[[0, 0, row, col]] = f32::from(r) / 255.0;
        array[[0, 1, row, col]] = f32::from(g) / 255.0;
        array[[0, 2, row, col]] = f32::from(b) / 255.0;
    }

    println!("Preprocessing complete. Tensor shape: {:?}", array.shape());
    Ok(array)
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::RgbImage;

    /// Builds the blank 100x100 RGB image shared by the tests below.
    fn blank_image() -> DynamicImage {
        DynamicImage::ImageRgb8(RgbImage::new(100, 100))
    }

    /// The output tensor always has the requested `[1, 3, H, W]` shape.
    #[test]
    fn test_preprocess_shape() {
        let tensor = preprocess_image(&blank_image(), 1024, 1024).unwrap();
        assert_eq!(tensor.shape(), &[1, 3, 1024, 1024]);
    }

    /// Every element of the output tensor lies in [0.0, 1.0].
    #[test]
    fn test_preprocess_normalization() {
        let tensor = preprocess_image(&blank_image(), 1024, 1024).unwrap();
        for value in tensor.iter().copied() {
            assert!((0.0..=1.0).contains(&value));
        }
    }
}

View file

@ -1,54 +0,0 @@
use ort::session::Session as OrtSession;
use std::error::Error;
use std::path::Path;
use crate::model::{create_session, get_model_path};
use super::Session;
/// Session wrapper for the `birefnet-general-lite` model.
pub struct BiRefNetLiteSession {
    /// Underlying ONNX Runtime session.
    inner_session: OrtSession,
    /// Name of the model's first input, captured when the session is built.
    input_name: String,
}

impl BiRefNetLiteSession {
    /// Resolve the built-in model file through `get_model_path` and load it.
    ///
    /// `offline` is forwarded to `get_model_path` unchanged.
    pub fn new(offline: bool) -> Result<Self, Box<dyn Error>> {
        let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?;
        Self::from_model_path(&path)
    }

    /// Load the model at `path` and record the name of its first input.
    ///
    /// # Errors
    ///
    /// Returns an error if the session cannot be created or if the model
    /// declares no inputs (previously this case panicked on `inputs()[0]`).
    pub fn from_model_path(path: &Path) -> Result<Self, Box<dyn Error>> {
        let inner_session = create_session(path)?;
        let input_name = inner_session
            .inputs()
            .first()
            .ok_or_else(|| format!("Model at {} declares no inputs", path.display()))?
            .name()
            .to_string();
        Ok(Self {
            inner_session,
            input_name,
        })
    }
}

impl Session for BiRefNetLiteSession {
    fn name() -> &'static str {
        "birefnet-general-lite"
    }

    fn url() -> &'static str {
        "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx"
    }

    fn sha256() -> Option<&'static str> {
        Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333")
    }

    /// This model's raw output is passed through a sigmoid before min/max scaling.
    fn apply_sigmoid(&self) -> bool {
        true
    }

    fn inner(&mut self) -> &mut OrtSession {
        &mut self.inner_session
    }

    fn input_name(&self) -> &str {
        &self.input_name
    }
}

View file

@ -1,50 +0,0 @@
use ort::session::Session as OrtSession;
use std::error::Error;
use std::path::Path;
use crate::model::{create_session, get_model_path};
use super::Session;
/// Session wrapper for the `bria-rmbg` model.
pub struct BriaSession {
    /// Underlying ONNX Runtime session.
    inner_session: OrtSession,
    /// Name of the model's first input, captured when the session is built.
    input_name: String,
}

impl BriaSession {
    /// Resolve the built-in model file through `get_model_path` and load it.
    ///
    /// `offline` is forwarded to `get_model_path` unchanged.
    pub fn new(offline: bool) -> Result<Self, Box<dyn Error>> {
        let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?;
        Self::from_model_path(&path)
    }

    /// Load the model at `path` and record the name of its first input.
    ///
    /// # Errors
    ///
    /// Returns an error if the session cannot be created or if the model
    /// declares no inputs (previously this case panicked on `inputs()[0]`).
    pub fn from_model_path(path: &Path) -> Result<Self, Box<dyn Error>> {
        let inner_session = create_session(path)?;
        let input_name = inner_session
            .inputs()
            .first()
            .ok_or_else(|| format!("Model at {} declares no inputs", path.display()))?
            .name()
            .to_string();
        Ok(Self {
            inner_session,
            input_name,
        })
    }
}

// `apply_sigmoid` is not overridden, so the trait's default applies.
impl Session for BriaSession {
    fn name() -> &'static str {
        "bria-rmbg"
    }

    fn url() -> &'static str {
        "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx"
    }

    fn sha256() -> Option<&'static str> {
        Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958")
    }

    fn inner(&mut self) -> &mut OrtSession {
        &mut self.inner_session
    }

    fn input_name(&self) -> &str {
        &self.input_name
    }
}

View file

@ -1,178 +0,0 @@
use image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType};
use ndarray::{Array4, IntoDimension};
use ort::{session::Session as OrtSession, value::Tensor};
use std::error::Error;
mod birefnet_lite;
mod bria;
pub use birefnet_lite::BiRefNetLiteSession;
pub use bria::BriaSession;
/// Common interface for background-removal models, mirroring rembg's
/// `BaseSession` pattern. Each concrete session owns an `ort::Session` and
/// implements `inner` / `input_name`; the rest (preprocessing, inference,
/// postprocessing into a mask) is provided by default implementations.
pub trait Session {
/// Canonical model name (matches rembg's filename stem).
fn name() -> &'static str
where
Self: Sized;
/// URL to download the ONNX model from.
fn url() -> &'static str
where
Self: Sized;
/// Optional SHA-256 checksum used to verify the cached model.
fn sha256() -> Option<&'static str>
where
Self: Sized,
{
None
}
fn mean(&self) -> (f32, f32, f32) {
(0.485, 0.456, 0.406)
}
fn std(&self) -> (f32, f32, f32) {
(0.229, 0.224, 0.225)
}
fn input_size(&self) -> (u32, u32) {
(1024, 1024)
}
/// Whether a sigmoid should be applied to the raw logits before the
/// min/max normalization step. `birefnet-*` needs this; `bria-rmbg` does not.
fn apply_sigmoid(&self) -> bool {
false
}
fn inner(&mut self) -> &mut OrtSession;
fn input_name(&self) -> &str;
/// Port of rembg's `BaseSession.normalize`: resize with LANCZOS,
/// scale into `[0, 1]` by dividing by the max pixel value, then apply
/// channel-wise mean/std.
fn normalize(&self, img: &DynamicImage) -> Result<Array4<f32>, Box<dyn Error>> {
let (w, h) = self.input_size();
let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8();
let (width, height) = resized.dimensions();
let mut max_pixel: f32 = 0.0;
for p in resized.pixels() {
for c in 0..3 {
let v = p[c] as f32;
if v > max_pixel {
max_pixel = v;
}
}
}
let denom = max_pixel.max(1e-6);
let mean = self.mean();
let std = self.std();
let mut array = Array4::<f32>::zeros((1, 3, height as usize, width as usize));
for y in 0..height {
for x in 0..width {
let pixel = resized.get_pixel(x, y);
let r = pixel[0] as f32 / denom;
let g = pixel[1] as f32 / denom;
let b = pixel[2] as f32 / denom;
array[[0, 0, y as usize, x as usize]] = (r - mean.0) / std.0;
array[[0, 1, y as usize, x as usize]] = (g - mean.1) / std.1;
array[[0, 2, y as usize, x as usize]] = (b - mean.2) / std.2;
}
}
Ok(array)
}
/// Run inference and return a grayscale mask resized to the input image.
/// Mirrors rembg's `predict`:
/// 1. `inner_session.run(normalize(img, mean, std, size))`
/// 2. take the first output, channel 0
/// 3. optional sigmoid (birefnet)
/// 4. min/max normalize into `[0, 1]`
/// 5. scale to `u8`, resize to original image dimensions
fn predict(&mut self, img: &DynamicImage) -> Result<GrayImage, Box<dyn Error>> {
let (orig_w, orig_h) = img.dimensions();
let input = self.normalize(img)?;
let apply_sigmoid = self.apply_sigmoid();
let input_name = self.input_name().to_string();
let outputs = self
.inner()
.run(ort::inputs![input_name => Tensor::from_array(input)?])?;
let output = &outputs[0];
let (shape, data) = output.try_extract_tensor::<f32>()?;
if shape.len() != 4 {
return Err(format!(
"Expected 4D output tensor [N, C, H, W], got shape {:?}",
shape
)
.into());
}
let (n, _c, h, w) = (
shape[0] as usize,
shape[1] as usize,
shape[2] as usize,
shape[3] as usize,
);
if n != 1 {
return Err(format!("Expected batch size 1, got {}", n).into());
}
let view = ndarray::ArrayView::from_shape(
(
shape[0] as usize,
shape[1] as usize,
shape[2] as usize,
shape[3] as usize,
)
.into_dimension(),
data,
)?;
// Take channel 0: pred = out[:, 0, :, :]
let mut pred: Vec<f32> = Vec::with_capacity(h * w);
for y in 0..h {
for x in 0..w {
let mut v = view[[0, 0, y, x]];
if apply_sigmoid {
v = 1.0 / (1.0 + (-v).exp());
}
pred.push(v);
}
}
let (mut mi, mut ma) = (f32::INFINITY, f32::NEG_INFINITY);
for &v in &pred {
if v < mi {
mi = v;
}
if v > ma {
ma = v;
}
}
let range = (ma - mi).max(1e-6);
let mut mask = GrayImage::new(w as u32, h as u32);
for y in 0..h {
for x in 0..w {
let v = (pred[y * w + x] - mi) / range;
let u = (v.clamp(0.0, 1.0) * 255.0).round() as u8;
mask.put_pixel(x as u32, y as u32, Luma([u]));
}
}
let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3);
Ok(mask)
}
}