2026-04-22 20:09:13 +00:00
|
|
|
use {
|
|
|
|
|
image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType},
|
|
|
|
|
ndarray::{Array4, IntoDimension},
|
|
|
|
|
ort::{session::Session as OrtSession, value::Tensor},
|
|
|
|
|
std::error::Error,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
use crate::model::Model;
|
2026-04-22 14:44:55 +00:00
|
|
|
|
|
|
|
|
mod birefnet_lite;
|
|
|
|
|
mod bria;
|
|
|
|
|
|
|
|
|
|
pub use birefnet_lite::BiRefNetLiteSession;
|
|
|
|
|
pub use bria::BriaSession;
|
|
|
|
|
|
|
|
|
|
/// Common interface for background-removal models, mirroring rembg's
|
|
|
|
|
/// `BaseSession` pattern. Each concrete session owns an `ort::Session` and
|
|
|
|
|
/// implements `inner` / `input_name`; the rest (preprocessing, inference,
|
|
|
|
|
/// postprocessing into a mask) is provided by default implementations.
|
|
|
|
|
pub trait Session {
|
|
|
|
|
/// Canonical model name (matches rembg's filename stem).
|
|
|
|
|
fn name() -> &'static str
|
|
|
|
|
where
|
|
|
|
|
Self: Sized;
|
|
|
|
|
|
|
|
|
|
/// URL to download the ONNX model from.
|
|
|
|
|
fn url() -> &'static str
|
|
|
|
|
where
|
|
|
|
|
Self: Sized;
|
|
|
|
|
|
|
|
|
|
/// Optional SHA-256 checksum used to verify the cached model.
|
|
|
|
|
fn sha256() -> Option<&'static str>
|
|
|
|
|
where
|
|
|
|
|
Self: Sized,
|
|
|
|
|
{
|
|
|
|
|
None
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn mean(&self) -> (f32, f32, f32) {
|
|
|
|
|
(0.485, 0.456, 0.406)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn std(&self) -> (f32, f32, f32) {
|
|
|
|
|
(0.229, 0.224, 0.225)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn input_size(&self) -> (u32, u32) {
|
|
|
|
|
(1024, 1024)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Whether a sigmoid should be applied to the raw logits before the
|
|
|
|
|
/// min/max normalization step. `birefnet-*` needs this; `bria-rmbg` does not.
|
|
|
|
|
fn apply_sigmoid(&self) -> bool {
|
|
|
|
|
false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn inner(&mut self) -> &mut OrtSession;
|
|
|
|
|
|
|
|
|
|
fn input_name(&self) -> &str;
|
|
|
|
|
|
|
|
|
|
/// Port of rembg's `BaseSession.normalize`: resize with LANCZOS,
|
|
|
|
|
/// scale into `[0, 1]` by dividing by the max pixel value, then apply
|
|
|
|
|
/// channel-wise mean/std.
|
|
|
|
|
fn normalize(&self, img: &DynamicImage) -> Result<Array4<f32>, Box<dyn Error>> {
|
|
|
|
|
let (w, h) = self.input_size();
|
|
|
|
|
let resized = img.resize_exact(w, h, FilterType::Lanczos3).to_rgb8();
|
|
|
|
|
let (width, height) = resized.dimensions();
|
|
|
|
|
|
|
|
|
|
let mut max_pixel: f32 = 0.0;
|
|
|
|
|
for p in resized.pixels() {
|
|
|
|
|
for c in 0..3 {
|
|
|
|
|
let v = p[c] as f32;
|
|
|
|
|
if v > max_pixel {
|
|
|
|
|
max_pixel = v;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
let denom = max_pixel.max(1e-6);
|
|
|
|
|
|
|
|
|
|
let mean = self.mean();
|
|
|
|
|
let std = self.std();
|
|
|
|
|
|
|
|
|
|
let mut array = Array4::<f32>::zeros((1, 3, height as usize, width as usize));
|
|
|
|
|
for y in 0..height {
|
|
|
|
|
for x in 0..width {
|
|
|
|
|
let pixel = resized.get_pixel(x, y);
|
|
|
|
|
let r = pixel[0] as f32 / denom;
|
|
|
|
|
let g = pixel[1] as f32 / denom;
|
|
|
|
|
let b = pixel[2] as f32 / denom;
|
|
|
|
|
array[[0, 0, y as usize, x as usize]] = (r - mean.0) / std.0;
|
|
|
|
|
array[[0, 1, y as usize, x as usize]] = (g - mean.1) / std.1;
|
|
|
|
|
array[[0, 2, y as usize, x as usize]] = (b - mean.2) / std.2;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(array)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Run inference and return a grayscale mask resized to the input image.
|
|
|
|
|
/// Mirrors rembg's `predict`:
|
|
|
|
|
/// 1. `inner_session.run(normalize(img, mean, std, size))`
|
|
|
|
|
/// 2. take the first output, channel 0
|
|
|
|
|
/// 3. optional sigmoid (birefnet)
|
|
|
|
|
/// 4. min/max normalize into `[0, 1]`
|
|
|
|
|
/// 5. scale to `u8`, resize to original image dimensions
|
|
|
|
|
fn predict(&mut self, img: &DynamicImage) -> Result<GrayImage, Box<dyn Error>> {
|
|
|
|
|
let (orig_w, orig_h) = img.dimensions();
|
|
|
|
|
let input = self.normalize(img)?;
|
|
|
|
|
let apply_sigmoid = self.apply_sigmoid();
|
|
|
|
|
|
|
|
|
|
let input_name = self.input_name().to_string();
|
|
|
|
|
let outputs = self
|
|
|
|
|
.inner()
|
|
|
|
|
.run(ort::inputs![input_name => Tensor::from_array(input)?])?;
|
|
|
|
|
|
|
|
|
|
let output = &outputs[0];
|
|
|
|
|
let (shape, data) = output.try_extract_tensor::<f32>()?;
|
|
|
|
|
|
|
|
|
|
if shape.len() != 4 {
|
|
|
|
|
return Err(format!(
|
|
|
|
|
"Expected 4D output tensor [N, C, H, W], got shape {:?}",
|
|
|
|
|
shape
|
|
|
|
|
)
|
|
|
|
|
.into());
|
|
|
|
|
}
|
|
|
|
|
let (n, _c, h, w) = (
|
|
|
|
|
shape[0] as usize,
|
|
|
|
|
shape[1] as usize,
|
|
|
|
|
shape[2] as usize,
|
|
|
|
|
shape[3] as usize,
|
|
|
|
|
);
|
|
|
|
|
if n != 1 {
|
|
|
|
|
return Err(format!("Expected batch size 1, got {}", n).into());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let view = ndarray::ArrayView::from_shape(
|
|
|
|
|
(
|
|
|
|
|
shape[0] as usize,
|
|
|
|
|
shape[1] as usize,
|
|
|
|
|
shape[2] as usize,
|
|
|
|
|
shape[3] as usize,
|
|
|
|
|
)
|
|
|
|
|
.into_dimension(),
|
|
|
|
|
data,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
// Take channel 0: pred = out[:, 0, :, :]
|
|
|
|
|
let mut pred: Vec<f32> = Vec::with_capacity(h * w);
|
|
|
|
|
for y in 0..h {
|
|
|
|
|
for x in 0..w {
|
|
|
|
|
let mut v = view[[0, 0, y, x]];
|
|
|
|
|
if apply_sigmoid {
|
|
|
|
|
v = 1.0 / (1.0 + (-v).exp());
|
|
|
|
|
}
|
|
|
|
|
pred.push(v);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let (mut mi, mut ma) = (f32::INFINITY, f32::NEG_INFINITY);
|
|
|
|
|
for &v in &pred {
|
|
|
|
|
if v < mi {
|
|
|
|
|
mi = v;
|
|
|
|
|
}
|
|
|
|
|
if v > ma {
|
|
|
|
|
ma = v;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
let range = (ma - mi).max(1e-6);
|
|
|
|
|
|
|
|
|
|
let mut mask = GrayImage::new(w as u32, h as u32);
|
|
|
|
|
for y in 0..h {
|
|
|
|
|
for x in 0..w {
|
|
|
|
|
let v = (pred[y * w + x] - mi) / range;
|
|
|
|
|
let u = (v.clamp(0.0, 1.0) * 255.0).round() as u8;
|
|
|
|
|
mask.put_pixel(x as u32, y as u32, Luma([u]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3);
|
|
|
|
|
Ok(mask)
|
|
|
|
|
}
|
|
|
|
|
}
|
2026-04-22 20:09:13 +00:00
|
|
|
|
|
|
|
|
pub fn init_session(
|
|
|
|
|
custom_model_path: Option<&str>,
|
|
|
|
|
model: Model,
|
|
|
|
|
offline: bool,
|
|
|
|
|
) -> Result<Box<dyn Session>, Box<dyn Error>> {
|
|
|
|
|
Ok(if let Some(custom_path) = custom_model_path {
|
|
|
|
|
let path = std::path::PathBuf::from(custom_path);
|
|
|
|
|
if !path.exists() {
|
|
|
|
|
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
|
|
|
|
}
|
|
|
|
|
println!("Using custom model: {}", custom_path);
|
|
|
|
|
match model {
|
|
|
|
|
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
|
|
|
|
|
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
match model {
|
|
|
|
|
Model::Bria => {
|
|
|
|
|
println!("Using model: bria-rmbg");
|
|
|
|
|
Box::new(BriaSession::new(offline)?)
|
|
|
|
|
}
|
|
|
|
|
Model::BiRefNetLite => {
|
|
|
|
|
println!("Using model: birefnet-general-lite");
|
|
|
|
|
Box::new(BiRefNetLiteSession::new(offline)?)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|