diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..9971b51 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,3 @@ +pub mod model; +pub mod postprocessing; +pub mod sessions; diff --git a/src/main.rs b/src/main.rs index a816e59..ca9ede1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,13 @@ use clap::Parser; use image::GenericImageView; -use ndarray::IntoDimension; -use ort::value::Tensor; use show_image::{AsImageView, create_window, event}; use std::error::Error; -mod model; -mod postprocessing; -mod preprocessing; - -use model::{Model, create_session, get_model_path}; -use postprocessing::{apply_mask, create_side_by_side}; -use preprocessing::preprocess_image; +use remove_background::{ + model::Model, + postprocessing::{apply_mask, create_side_by_side}, + sessions::{BiRefNetLiteSession, BriaSession, Session}, +}; #[derive(Parser)] #[command(name = "remove_background")] @@ -39,13 +35,11 @@ struct Args { #[show_image::main] fn main() -> Result<(), Box> { - // Parse command line arguments let args = Args::parse(); println!("=== Background Removal Tool ===\n"); println!("Downloading image from: {}", args.url); - // Download the image with a user agent (using blocking client) let client = reqwest::blocking::Client::builder() .user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36") .build()?; @@ -58,69 +52,47 @@ fn main() -> Result<(), Box> { let bytes = response.bytes()?; println!("Downloaded {} bytes", bytes.len()); - // Load the image let img = image::load_from_memory(&bytes)?; let (img_width, img_height) = img.dimensions(); println!("Loaded image: {}x{}\n", img_width, img_height); - // Get model path - either from custom path or by selecting built-in model - let model_path = if let Some(custom_path) = args.model_path { - // Custom model path provided, use it directly - let model_path = std::path::PathBuf::from(&custom_path); - if !model_path.exists() { + let mut session: Box = if let Some(custom_path) = args.model_path.as_deref() { + let path = std::path::PathBuf::from(custom_path); + if !path.exists() { return Err(format!("Custom model path does not exist: {}", custom_path).into()); } println!("Using custom model: {}", custom_path); - model_path + match args.model { + Model::Bria => Box::new(BriaSession::from_model_path(&path)?), + Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?), + } } else { - // Use built-in model - let model_info = args.model.info(); - println!("Using model: {}", model_info.name); - get_model_path(&model_info, None, args.offline)? + match args.model { + Model::Bria => { + println!("Using model: bria-rmbg"); + Box::new(BriaSession::new(args.offline)?) + } + Model::BiRefNetLite => { + println!("Using model: birefnet-general-lite"); + Box::new(BiRefNetLiteSession::new(args.offline)?) + } + } }; - // Create ONNX Runtime session - let mut session = create_session(&model_path)?; - let input_name = session.inputs()[0].name().to_string(); - - // Preprocess - let input_tensor = preprocess_image(&img, 1024, 1024)?; - - // Run inference println!("Running inference..."); - let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?; - - // Extract mask output - let mask_output = &outputs[0]; - let (mask_shape, mask_array) = mask_output.try_extract_tensor::()?; // Returns (shape, &[f32]) - - // Convert the slice to Array4 - let mask_tensor = ndarray::ArrayView::from_shape( - ( - mask_shape[0] as usize, - mask_shape[1] as usize, - mask_shape[2] as usize, - mask_shape[3] as usize, - ) - .into_dimension(), - mask_array, - )? - .to_owned(); - + let mask = session.predict(&img)?; println!( - "Inference complete! Output shape: {:?}", - mask_tensor.shape() + "Inference complete! Mask dimensions: {}x{}", + mask.dimensions().0, + mask.dimensions().1 ); - // Apply mask to remove background - let result_rgba = apply_mask(&img, mask_tensor)?; + let result_rgba = apply_mask(&img, &mask)?; - // Create side-by-side comparison println!("Creating side-by-side comparison..."); let composite = create_side_by_side(&img, &result_rgba)?; let composite_dynamic = image::DynamicImage::ImageRgba8(composite); - // Display the result let (comp_width, comp_height) = composite_dynamic.dimensions(); let window = create_window( "Background Removal - Original (Left) vs Result (Right) - Press ESC to close", @@ -128,7 +100,7 @@ fn main() -> Result<(), Box> { )?; window.set_image( "comparison", - &composite_dynamic + composite_dynamic .as_image_view() .map_err(|e| e.to_string())?, )?; @@ -141,15 +113,14 @@ fn main() -> Result<(), Box> { println!(" Left: Original image"); println!(" Right: Background removed (shown on checkered background)"); println!("\nPress ESC to close the window."); - // Event loop - wait for ESC key to close + for event in window.event_channel()? { - if let event::WindowEvent::KeyboardInput(event) = event { - if event.input.key_code == Some(event::VirtualKeyCode::Escape) - && event.input.state.is_pressed() - { - println!("ESC pressed, closing..."); - break; - } + if let event::WindowEvent::KeyboardInput(event) = event + && event.input.key_code == Some(event::VirtualKeyCode::Escape) + && event.input.state.is_pressed() + { + println!("ESC pressed, closing..."); + break; } } diff --git a/src/model.rs b/src/model.rs index ad24bff..fac8412 100644 --- a/src/model.rs +++ b/src/model.rs @@ -7,6 +7,8 @@ use std::fs; use std::io::Write; use std::path::{Path, PathBuf}; +/// CLI-facing model selector. Concrete session metadata (URL, checksum, +/// preprocessing params) lives on the `Session` trait impls in `sessions/`. #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] #[clap(rename_all = "kebab-case")] pub enum Model { @@ -14,37 +16,6 @@ pub enum Model { Bria, } -impl Model { - pub fn info(&self) -> &ModelInfo { - match self { - Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE, - Model::Bria => &ModelInfo::BRIA, - } - } -} - -/// Model metadata -pub struct ModelInfo { - pub name: &'static str, - pub url: &'static str, - pub sha256: Option<&'static str>, -} - -impl ModelInfo { - /// BiRefNet General Lite model - pub const BIREFNET_LITE: ModelInfo = ModelInfo { - name: "birefnet-general-lite", - url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx", - sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"), - }; - - pub const BRIA: ModelInfo = ModelInfo { - name: "bria", - url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx", - sha256: Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"), - }; -} - /// Get the cache directory for models fn get_cache_dir() -> Result> { let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?; @@ -56,7 +27,6 @@ fn get_cache_dir() -> Result> { Ok(cache_dir) } -/// Download a file from URL to destination fn download_file(url: &str, dest: &Path) -> Result<(), Box> { println!("Downloading model from {}...", url); @@ -99,7 +69,6 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box> { Ok(()) } -/// Verify file SHA256 hash fn verify_hash(file_path: &Path, expected_hash: &str) -> Result> { let mut file = fs::File::open(file_path)?; let mut hasher = Sha256::new(); @@ -109,13 +78,20 @@ fn verify_hash(file_path: &Path, expected_hash: &str) -> Result, custom_path: Option<&str>, offline: bool, ) -> Result> { - // If custom path provided, use it if let Some(path) = custom_path { let model_path = PathBuf::from(path); if !model_path.exists() { @@ -125,16 +101,14 @@ pub fn get_model_path( return Ok(model_path); } - // Check cache let cache_dir = get_cache_dir()?; - let model_filename = format!("{}.onnx", model_info.name); + let model_filename = format!("{}.onnx", name); let model_path = cache_dir.join(&model_filename); if model_path.exists() { println!("Using cached model: {}", model_path.display()); - // Verify hash if provided - if let Some(expected_hash) = model_info.sha256 { + if let Some(expected_hash) = sha256 { print!("Verifying model integrity... "); std::io::stdout().flush()?; if verify_hash(&model_path, expected_hash)? { @@ -150,7 +124,6 @@ pub fn get_model_path( return Ok(model_path); } - // Download if not in offline mode if offline { return Err(format!( "Model not found in cache and offline mode is enabled. Cache path: {}", @@ -160,16 +133,14 @@ pub fn get_model_path( } println!("Model not found in cache, downloading..."); - download_file(model_info.url, &model_path)?; + download_file(url, &model_path)?; - // Verify after download - if let Some(expected_hash) = model_info.sha256 { + if let Some(expected_hash) = sha256 { print!("Verifying downloaded model... "); std::io::stdout().flush()?; if verify_hash(&model_path, expected_hash)? { println!("OK"); } else { - // Delete corrupted file fs::remove_file(&model_path)?; return Err("Downloaded model hash verification failed".into()); } @@ -178,7 +149,7 @@ pub fn get_model_path( Ok(model_path) } -/// Create an ONNX Runtime session from model path with CUDA backend +/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU). pub fn create_session(model_path: &Path) -> Result> { println!("Loading model into ONNX Runtime with CUDA backend..."); diff --git a/src/postprocessing.rs b/src/postprocessing.rs index 8865a3f..881a7cb 100644 --- a/src/postprocessing.rs +++ b/src/postprocessing.rs @@ -1,53 +1,32 @@ -use image::{DynamicImage, GenericImageView, Rgba, RgbaImage, imageops::FilterType}; -use ndarray::Array4; +use image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType}; use std::error::Error; -/// Apply mask to original image to remove background +/// Compose `original` with `mask` as the alpha channel and return an RGBA image. /// -/// Steps: -/// 1. Extract mask from output tensor (shape: [1, 1, H, W]) -/// 2. Resize mask to match original image dimensions -/// 3. Apply mask as alpha channel to create RGBA image -pub fn apply_mask( - original: &DynamicImage, - mask_tensor: Array4, -) -> Result> { +/// The mask is expected to already be grayscale. If its dimensions differ from +/// the original, it is resized with LANCZOS3. +pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result> { println!("Applying mask to remove background..."); let (orig_width, orig_height) = original.dimensions(); - - // Extract mask dimensions - let mask_shape = mask_tensor.shape(); - if mask_shape[0] != 1 || mask_shape[1] != 1 { - return Err(format!("Expected mask shape [1, 1, H, W], got {:?}", mask_shape).into()); - } - - let mask_height = mask_shape[2] as u32; - let mask_width = mask_shape[3] as u32; + let (mask_width, mask_height) = mask.dimensions(); println!("Mask dimensions: {}x{}", mask_width, mask_height); println!("Original dimensions: {}x{}", orig_width, orig_height); - // Create a grayscale image from the mask - let mut mask_image = image::GrayImage::new(mask_width, mask_height); - for y in 0..mask_height { - for x in 0..mask_width { - let mask_value = mask_tensor[[0, 0, y as usize, x as usize]]; - // Clamp and convert to u8 - let pixel_value = (mask_value.clamp(0.0, 1.0) * 255.0) as u8; - mask_image.put_pixel(x, y, image::Luma([pixel_value])); - } - } + let resized_mask: std::borrow::Cow<'_, GrayImage> = + if mask_width != orig_width || mask_height != orig_height { + println!("Resizing mask to match original image..."); + std::borrow::Cow::Owned(image::imageops::resize( + mask, + orig_width, + orig_height, + FilterType::Lanczos3, + )) + } else { + std::borrow::Cow::Borrowed(mask) + }; - // Resize mask to match original image dimensions if needed - let resized_mask = if mask_width != orig_width || mask_height != orig_height { - println!("Resizing mask to match original image..."); - image::imageops::resize(&mask_image, orig_width, orig_height, FilterType::Lanczos3) - } else { - mask_image - }; - - // Convert original to RGBA and apply mask let rgba_original = original.to_rgba8(); let mut result = RgbaImage::new(orig_width, orig_height); @@ -56,16 +35,10 @@ pub fn apply_mask( let orig_pixel = rgba_original.get_pixel(x, y); let mask_pixel = resized_mask.get_pixel(x, y); - // Apply mask as alpha channel result.put_pixel( x, y, - Rgba([ - orig_pixel[0], // R - orig_pixel[1], // G - orig_pixel[2], // B - mask_pixel[0], // Alpha from mask - ]), + Rgba([orig_pixel[0], orig_pixel[1], orig_pixel[2], mask_pixel[0]]), ); } } @@ -83,7 +56,6 @@ pub fn create_side_by_side( let (width, height) = original.dimensions(); let mut composite = RgbaImage::new(width * 2, height); - // Left side: original image let original_rgba = original.to_rgba8(); for y in 0..height { for x in 0..width { @@ -97,12 +69,10 @@ pub fn create_side_by_side( let result_pixel = result.get_pixel(x, y); let alpha = result_pixel[3] as f32 / 255.0; - // Create checkered background (8x8 squares) let checker_size = 8; let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0; let bg_color = if is_light { 200 } else { 150 }; - // Alpha blend with checkered background let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; @@ -117,13 +87,16 @@ pub fn create_side_by_side( #[cfg(test)] mod tests { use super::*; - use ndarray::Array4; + use image::{GrayImage, Luma}; #[test] fn test_apply_mask_shape() { let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100)); - let mask = Array4::::ones((1, 1, 100, 100)); - let result = apply_mask(&img, mask).unwrap(); + let mut mask = GrayImage::new(100, 100); + for p in mask.pixels_mut() { + *p = Luma([255]); + } + let result = apply_mask(&img, &mask).unwrap(); assert_eq!(result.dimensions(), (100, 100)); } @@ -132,6 +105,6 @@ mod tests { let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100)); let result_img = RgbaImage::new(100, 100); let composite = create_side_by_side(&img, &result_img).unwrap(); - assert_eq!(composite.dimensions(), (200, 100)); // Double width + assert_eq!(composite.dimensions(), (200, 100)); } } diff --git a/src/preprocessing.rs b/src/preprocessing.rs deleted file mode 100644 index 3f7e1cd..0000000 --- a/src/preprocessing.rs +++ /dev/null @@ -1,67 +0,0 @@ -use image::{DynamicImage, imageops::FilterType}; -use ndarray::Array4; -use std::error::Error; - -/// Preprocess an image for the BiRefNet model -/// -/// Steps: -/// 1. Resize to target dimensions (1024x1024) -/// 2. Convert from u8 [0, 255] to f32 [0.0, 1.0] -/// 3. Rearrange from HWC (Height, Width, Channels) to CHW format -/// 4. Add batch dimension: [1, 3, H, W] -pub fn preprocess_image( - img: &DynamicImage, - target_width: u32, - target_height: u32, -) -> Result, Box> { - println!("Preprocessing image..."); - - // Step 1: Resize image - let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3); - let rgb_image = resized.to_rgb8(); - - let (width, height) = rgb_image.dimensions(); - - // Step 2: Create ndarray with shape [1, 3, height, width] - let mut array = Array4::::zeros((1, 3, height as usize, width as usize)); - - // Step 3: Fill the array, converting from HWC to CHW and normalizing - for y in 0..height { - for x in 0..width { - let pixel = rgb_image.get_pixel(x, y); - - // Normalize from [0, 255] to [0.0, 1.0] - array[[0, 0, y as usize, x as usize]] = pixel[0] as f32 / 255.0; // R - array[[0, 1, y as usize, x as usize]] = pixel[1] as f32 / 255.0; // G - array[[0, 2, y as usize, x as usize]] = pixel[2] as f32 / 255.0; // B - } - } - - println!("Preprocessing complete. Tensor shape: {:?}", array.shape()); - - Ok(array) -} - -#[cfg(test)] -mod tests { - use super::*; - use image::RgbImage; - - #[test] - fn test_preprocess_shape() { - let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100)); - let result = preprocess_image(&img, 1024, 1024).unwrap(); - assert_eq!(result.shape(), &[1, 3, 1024, 1024]); - } - - #[test] - fn test_preprocess_normalization() { - let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100)); - let result = preprocess_image(&img, 1024, 1024).unwrap(); - - // Check that all values are in [0.0, 1.0] - for &val in result.iter() { - assert!(val >= 0.0 && val <= 1.0); - } - } -} diff --git a/src/sessions/birefnet_lite.rs b/src/sessions/birefnet_lite.rs new file mode 100644 index 0000000..e8f7949 --- /dev/null +++ b/src/sessions/birefnet_lite.rs @@ -0,0 +1,54 @@ +use ort::session::Session as OrtSession; +use std::error::Error; +use std::path::Path; + +use crate::model::{create_session, get_model_path}; + +use super::Session; + +pub struct BiRefNetLiteSession { + inner_session: OrtSession, + input_name: String, +} + +impl BiRefNetLiteSession { + pub fn new(offline: bool) -> Result> { + let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?; + Self::from_model_path(&path) + } + + pub fn from_model_path(path: &Path) -> Result> { + let inner_session = create_session(path)?; + let input_name = inner_session.inputs()[0].name().to_string(); + Ok(Self { + inner_session, + input_name, + }) + } +} + +impl Session for BiRefNetLiteSession { + fn name() -> &'static str { + "birefnet-general-lite" + } + + fn url() -> &'static str { + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx" + } + + fn sha256() -> Option<&'static str> { + Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333") + } + + fn apply_sigmoid(&self) -> bool { + true + } + + fn inner(&mut self) -> &mut OrtSession { + &mut self.inner_session + } + + fn input_name(&self) -> &str { + &self.input_name + } +} diff --git a/src/sessions/bria.rs b/src/sessions/bria.rs new file mode 100644 index 0000000..4cae5f5 --- /dev/null +++ b/src/sessions/bria.rs @@ -0,0 +1,50 @@ +use ort::session::Session as OrtSession; +use std::error::Error; +use std::path::Path; + +use crate::model::{create_session, get_model_path}; + +use super::Session; + +pub struct BriaSession { + inner_session: OrtSession, + input_name: String, +} + +impl BriaSession { + pub fn new(offline: bool) -> Result> { + let path = get_model_path(Self::name(), Self::url(), Self::sha256(), None, offline)?; + Self::from_model_path(&path) + } + + pub fn from_model_path(path: &Path) -> Result> { + let inner_session = create_session(path)?; + let input_name = inner_session.inputs()[0].name().to_string(); + Ok(Self { + inner_session, + input_name, + }) + } +} + +impl Session for BriaSession { + fn name() -> &'static str { + "bria-rmbg" + } + + fn url() -> &'static str { + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx" + } + + fn sha256() -> Option<&'static str> { + Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958") + } + + fn inner(&mut self) -> &mut OrtSession { + &mut self.inner_session + } + + fn input_name(&self) -> &str { + &self.input_name + } +} diff --git a/src/sessions/mod.rs b/src/sessions/mod.rs new file mode 100644 index 0000000..bb8a94f --- /dev/null +++ b/src/sessions/mod.rs @@ -0,0 +1,178 @@ +use image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType}; +use ndarray::{Array4, IntoDimension}; +use ort::{session::Session as OrtSession, value::Tensor}; +use std::error::Error; + +mod birefnet_lite; +mod bria; + +pub use birefnet_lite::BiRefNetLiteSession; +pub use bria::BriaSession; + +/// Common interface for background-removal models, mirroring rembg's +/// `BaseSession` pattern. Each concrete session owns an `ort::Session` and +/// implements `inner` / `input_name`; the rest (preprocessing, inference, +/// postprocessing into a mask) is provided by default implementations. +pub trait Session { + /// Canonical model name (matches rembg's filename stem). + fn name() -> &'static str + where + Self: Sized; + + /// URL to download the ONNX model from. + fn url() -> &'static str + where + Self: Sized; + + /// Optional SHA-256 checksum used to verify the cached model. + fn sha256() -> Option<&'static str> + where + Self: Sized, + { + None + } + + fn mean(&self) -> (f32, f32, f32) { + (0.485, 0.456, 0.406) + } + + fn std(&self) -> (f32, f32, f32) { + (0.229, 0.224, 0.225) + } + + fn input_size(&self) -> (u32, u32) { + (1024, 1024) + } + + /// Whether a sigmoid should be applied to the raw logits before the + /// min/max normalization step. `birefnet-*` needs this; `bria-rmbg` does not. + fn apply_sigmoid(&self) -> bool { + false + } + + fn inner(&mut self) -> &mut OrtSession; + + fn input_name(&self) -> &str; + + /// Port of rembg's `BaseSession.normalize`: resize with LANCZOS, + /// scale into `[0, 1]` by dividing by the max pixel value, then apply + /// channel-wise mean/std. + fn normalize(&self, img: &DynamicImage) -> Result, Box> { + let (w, h) = self.input_size(); + let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8(); + let (width, height) = resized.dimensions(); + + let mut max_pixel: f32 = 0.0; + for p in resized.pixels() { + for c in 0..3 { + let v = p[c] as f32; + if v > max_pixel { + max_pixel = v; + } + } + } + let denom = max_pixel.max(1e-6); + + let mean = self.mean(); + let std = self.std(); + + let mut array = Array4::::zeros((1, 3, height as usize, width as usize)); + for y in 0..height { + for x in 0..width { + let pixel = resized.get_pixel(x, y); + let r = pixel[0] as f32 / denom; + let g = pixel[1] as f32 / denom; + let b = pixel[2] as f32 / denom; + array[[0, 0, y as usize, x as usize]] = (r - mean.0) / std.0; + array[[0, 1, y as usize, x as usize]] = (g - mean.1) / std.1; + array[[0, 2, y as usize, x as usize]] = (b - mean.2) / std.2; + } + } + + Ok(array) + } + + /// Run inference and return a grayscale mask resized to the input image. + /// Mirrors rembg's `predict`: + /// 1. `inner_session.run(normalize(img, mean, std, size))` + /// 2. take the first output, channel 0 + /// 3. optional sigmoid (birefnet) + /// 4. min/max normalize into `[0, 1]` + /// 5. scale to `u8`, resize to original image dimensions + fn predict(&mut self, img: &DynamicImage) -> Result> { + let (orig_w, orig_h) = img.dimensions(); + let input = self.normalize(img)?; + let apply_sigmoid = self.apply_sigmoid(); + + let input_name = self.input_name().to_string(); + let outputs = self + .inner() + .run(ort::inputs![input_name => Tensor::from_array(input)?])?; + + let output = &outputs[0]; + let (shape, data) = output.try_extract_tensor::()?; + + if shape.len() != 4 { + return Err(format!( + "Expected 4D output tensor [N, C, H, W], got shape {:?}", + shape + ) + .into()); + } + let (n, _c, h, w) = ( + shape[0] as usize, + shape[1] as usize, + shape[2] as usize, + shape[3] as usize, + ); + if n != 1 { + return Err(format!("Expected batch size 1, got {}", n).into()); + } + + let view = ndarray::ArrayView::from_shape( + ( + shape[0] as usize, + shape[1] as usize, + shape[2] as usize, + shape[3] as usize, + ) + .into_dimension(), + data, + )?; + + // Take channel 0: pred = out[:, 0, :, :] + let mut pred: Vec = Vec::with_capacity(h * w); + for y in 0..h { + for x in 0..w { + let mut v = view[[0, 0, y, x]]; + if apply_sigmoid { + v = 1.0 / (1.0 + (-v).exp()); + } + pred.push(v); + } + } + + let (mut mi, mut ma) = (f32::INFINITY, f32::NEG_INFINITY); + for &v in &pred { + if v < mi { + mi = v; + } + if v > ma { + ma = v; + } + } + let range = (ma - mi).max(1e-6); + + let mut mask = GrayImage::new(w as u32, h as u32); + for y in 0..h { + for x in 0..w { + let v = (pred[y * w + x] - mi) / range; + let u = (v.clamp(0.0, 1.0) * 255.0).round() as u8; + mask.put_pixel(x as u32, y as u32, Luma([u])); + } + } + + let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3); + Ok(mask) + } +}