remove_background/src/sessions/mod.rs

179 lines
5.5 KiB
Rust
Raw Normal View History

2026-04-22 14:44:55 +00:00
use image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType};
use ndarray::{Array4, IntoDimension};
use ort::{session::Session as OrtSession, value::Tensor};
use std::error::Error;
mod birefnet_lite;
mod bria;
pub use birefnet_lite::BiRefNetLiteSession;
pub use bria::BriaSession;
/// Common interface for background-removal models, mirroring rembg's
/// `BaseSession` pattern. Each concrete session owns an `ort::Session` and
/// implements `inner` / `input_name`; the rest (preprocessing, inference,
/// postprocessing into a mask) is provided by default implementations.
pub trait Session {
/// Canonical model name (matches rembg's filename stem).
fn name() -> &'static str
where
Self: Sized;
/// URL to download the ONNX model from.
fn url() -> &'static str
where
Self: Sized;
/// Optional SHA-256 checksum used to verify the cached model.
fn sha256() -> Option<&'static str>
where
Self: Sized,
{
None
}
fn mean(&self) -> (f32, f32, f32) {
(0.485, 0.456, 0.406)
}
fn std(&self) -> (f32, f32, f32) {
(0.229, 0.224, 0.225)
}
fn input_size(&self) -> (u32, u32) {
(1024, 1024)
}
/// Whether a sigmoid should be applied to the raw logits before the
/// min/max normalization step. `birefnet-*` needs this; `bria-rmbg` does not.
fn apply_sigmoid(&self) -> bool {
false
}
fn inner(&mut self) -> &mut OrtSession;
fn input_name(&self) -> &str;
/// Port of rembg's `BaseSession.normalize`: resize with LANCZOS,
/// scale into `[0, 1]` by dividing by the max pixel value, then apply
/// channel-wise mean/std.
fn normalize(&self, img: &DynamicImage) -> Result<Array4<f32>, Box<dyn Error>> {
let (w, h) = self.input_size();
let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8();
let (width, height) = resized.dimensions();
let mut max_pixel: f32 = 0.0;
for p in resized.pixels() {
for c in 0..3 {
let v = p[c] as f32;
if v > max_pixel {
max_pixel = v;
}
}
}
let denom = max_pixel.max(1e-6);
let mean = self.mean();
let std = self.std();
let mut array = Array4::<f32>::zeros((1, 3, height as usize, width as usize));
for y in 0..height {
for x in 0..width {
let pixel = resized.get_pixel(x, y);
let r = pixel[0] as f32 / denom;
let g = pixel[1] as f32 / denom;
let b = pixel[2] as f32 / denom;
array[[0, 0, y as usize, x as usize]] = (r - mean.0) / std.0;
array[[0, 1, y as usize, x as usize]] = (g - mean.1) / std.1;
array[[0, 2, y as usize, x as usize]] = (b - mean.2) / std.2;
}
}
Ok(array)
}
/// Run inference and return a grayscale mask resized to the input image.
/// Mirrors rembg's `predict`:
/// 1. `inner_session.run(normalize(img, mean, std, size))`
/// 2. take the first output, channel 0
/// 3. optional sigmoid (birefnet)
/// 4. min/max normalize into `[0, 1]`
/// 5. scale to `u8`, resize to original image dimensions
fn predict(&mut self, img: &DynamicImage) -> Result<GrayImage, Box<dyn Error>> {
let (orig_w, orig_h) = img.dimensions();
let input = self.normalize(img)?;
let apply_sigmoid = self.apply_sigmoid();
let input_name = self.input_name().to_string();
let outputs = self
.inner()
.run(ort::inputs![input_name => Tensor::from_array(input)?])?;
let output = &outputs[0];
let (shape, data) = output.try_extract_tensor::<f32>()?;
if shape.len() != 4 {
return Err(format!(
"Expected 4D output tensor [N, C, H, W], got shape {:?}",
shape
)
.into());
}
let (n, _c, h, w) = (
shape[0] as usize,
shape[1] as usize,
shape[2] as usize,
shape[3] as usize,
);
if n != 1 {
return Err(format!("Expected batch size 1, got {}", n).into());
}
let view = ndarray::ArrayView::from_shape(
(
shape[0] as usize,
shape[1] as usize,
shape[2] as usize,
shape[3] as usize,
)
.into_dimension(),
data,
)?;
// Take channel 0: pred = out[:, 0, :, :]
let mut pred: Vec<f32> = Vec::with_capacity(h * w);
for y in 0..h {
for x in 0..w {
let mut v = view[[0, 0, y, x]];
if apply_sigmoid {
v = 1.0 / (1.0 + (-v).exp());
}
pred.push(v);
}
}
let (mut mi, mut ma) = (f32::INFINITY, f32::NEG_INFINITY);
for &v in &pred {
if v < mi {
mi = v;
}
if v > ma {
ma = v;
}
}
let range = (ma - mi).max(1e-6);
let mut mask = GrayImage::new(w as u32, h as u32);
for y in 0..h {
for x in 0..w {
let v = (pred[y * w + x] - mi) / range;
let u = (v.clamp(0.0, 1.0) * 255.0).round() as u8;
mask.put_pixel(x as u32, y as u32, Luma([u]));
}
}
let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3);
Ok(mask)
}
}