use { image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType}, ndarray::{Array4, IntoDimension}, ort::{session::Session as OrtSession, value::Tensor}, std::error::Error, }; use crate::model::Model; mod birefnet_lite; mod bria; pub use birefnet_lite::BiRefNetLiteSession; pub use bria::BriaSession; /// Common interface for background-removal models, mirroring rembg's /// `BaseSession` pattern. Each concrete session owns an `ort::Session` and /// implements `inner` / `input_name`; the rest (preprocessing, inference, /// postprocessing into a mask) is provided by default implementations. pub trait Session { /// Canonical model name (matches rembg's filename stem). fn name() -> &'static str where Self: Sized; /// URL to download the ONNX model from. fn url() -> &'static str where Self: Sized; /// Optional SHA-256 checksum used to verify the cached model. fn sha256() -> Option<&'static str> where Self: Sized, { None } fn mean(&self) -> (f32, f32, f32) { (0.485, 0.456, 0.406) } fn std(&self) -> (f32, f32, f32) { (0.229, 0.224, 0.225) } fn input_size(&self) -> (u32, u32) { (1024, 1024) } /// Whether a sigmoid should be applied to the raw logits before the /// min/max normalization step. `birefnet-*` needs this; `bria-rmbg` does not. fn apply_sigmoid(&self) -> bool { false } fn inner(&mut self) -> &mut OrtSession; fn input_name(&self) -> &str; /// Port of rembg's `BaseSession.normalize`: resize with LANCZOS, /// scale into `[0, 1]` by dividing by the max pixel value, then apply /// channel-wise mean/std. fn normalize(&self, img: &DynamicImage) -> Result, Box> { let (w, h) = self.input_size(); let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8(); let (width, height) = resized.dimensions(); let mut max_pixel: f32 = 0.0; for p in resized.pixels() { for c in 0..3 { let v = p[c] as f32; if v > max_pixel { max_pixel = v; } } } let denom = max_pixel.max(1e-6); let mean = self.mean(); let std = self.std(); let mut array = Array4::::zeros((1, 3, height as usize, width as usize)); for y in 0..height { for x in 0..width { let pixel = resized.get_pixel(x, y); let r = pixel[0] as f32 / denom; let g = pixel[1] as f32 / denom; let b = pixel[2] as f32 / denom; array[[0, 0, y as usize, x as usize]] = (r - mean.0) / std.0; array[[0, 1, y as usize, x as usize]] = (g - mean.1) / std.1; array[[0, 2, y as usize, x as usize]] = (b - mean.2) / std.2; } } Ok(array) } /// Run inference and return a grayscale mask resized to the input image. /// Mirrors rembg's `predict`: /// 1. `inner_session.run(normalize(img, mean, std, size))` /// 2. take the first output, channel 0 /// 3. optional sigmoid (birefnet) /// 4. min/max normalize into `[0, 1]` /// 5. scale to `u8`, resize to original image dimensions fn predict(&mut self, img: &DynamicImage) -> Result> { let (orig_w, orig_h) = img.dimensions(); let input = self.normalize(img)?; let apply_sigmoid = self.apply_sigmoid(); let input_name = self.input_name().to_string(); let outputs = self .inner() .run(ort::inputs![input_name => Tensor::from_array(input)?])?; let output = &outputs[0]; let (shape, data) = output.try_extract_tensor::()?; if shape.len() != 4 { return Err(format!( "Expected 4D output tensor [N, C, H, W], got shape {:?}", shape ) .into()); } let (n, _c, h, w) = ( shape[0] as usize, shape[1] as usize, shape[2] as usize, shape[3] as usize, ); if n != 1 { return Err(format!("Expected batch size 1, got {}", n).into()); } let view = ndarray::ArrayView::from_shape( ( shape[0] as usize, shape[1] as usize, shape[2] as usize, shape[3] as usize, ) .into_dimension(), data, )?; // Take channel 0: pred = out[:, 0, :, :] let mut pred: Vec = Vec::with_capacity(h * w); for y in 0..h { for x in 0..w { let mut v = view[[0, 0, y, x]]; if apply_sigmoid { v = 1.0 / (1.0 + (-v).exp()); } pred.push(v); } } let (mut mi, mut ma) = (f32::INFINITY, f32::NEG_INFINITY); for &v in &pred { if v < mi { mi = v; } if v > ma { ma = v; } } let range = (ma - mi).max(1e-6); let mut mask = GrayImage::new(w as u32, h as u32); for y in 0..h { for x in 0..w { let v = (pred[y * w + x] - mi) / range; let u = (v.clamp(0.0, 1.0) * 255.0).round() as u8; mask.put_pixel(x as u32, y as u32, Luma([u])); } } let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3); Ok(mask) } } pub fn init_session( custom_model_path: Option<&str>, model: Model, offline: bool, ) -> Result, Box> { Ok(if let Some(custom_path) = custom_model_path { let path = std::path::PathBuf::from(custom_path); if !path.exists() { return Err(format!("Custom model path does not exist: {}", custom_path).into()); } println!("Using custom model: {}", custom_path); match model { Model::Bria => Box::new(BriaSession::from_model_path(&path)?), Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?), } } else { match model { Model::Bria => { println!("Using model: bria-rmbg"); Box::new(BriaSession::new(offline)?) } Model::BiRefNetLite => { println!("Using model: birefnet-general-lite"); Box::new(BiRefNetLiteSession::new(offline)?) } } }) }