infer input size

This commit is contained in:
Matthew Deville 2026-01-22 21:48:08 +01:00
parent f6961c3177
commit 9e9721cbbc
3 changed files with 84 additions and 12 deletions

View file

@ -20,7 +20,9 @@
overlays = [ (import rust-overlay) ]; overlays = [ (import rust-overlay) ];
pkgs = import nixpkgs { inherit system overlays; }; pkgs = import nixpkgs { inherit system overlays; };
stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv; stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv;
rustToolchain = pkgs.rust-bin.stable.latest.default; rustToolchain = pkgs.rust-bin.stable.latest.default.override {
extensions = [ "rust-src" ];
};
nativeBuildInputs = with pkgs; [ nativeBuildInputs = with pkgs; [
rustToolchain rustToolchain

View file

@ -9,7 +9,7 @@ mod model;
mod postprocessing; mod postprocessing;
mod preprocessing; mod preprocessing;
use model::{ModelInfo, create_session, get_model_path}; use model::{ModelInfo, create_session, get_model_path, infer_input_info};
use postprocessing::{apply_mask, create_side_by_side}; use postprocessing::{apply_mask, create_side_by_side};
use preprocessing::preprocess_image; use preprocessing::preprocess_image;
@ -50,8 +50,8 @@ fn main() -> Result<(), Box<dyn Error>> {
// Load the image // Load the image
let img = image::load_from_memory(&bytes)?; let img = image::load_from_memory(&bytes)?;
let (width, height) = img.dimensions(); let (img_width, img_height) = img.dimensions();
println!("Loaded image: {}x{}\n", width, height); println!("Loaded image: {}x{}\n", img_width, img_height);
// Get model info // Get model info
let model_info = ModelInfo::BIREFNET_LITE; let model_info = ModelInfo::BIREFNET_LITE;
@ -63,14 +63,15 @@ fn main() -> Result<(), Box<dyn Error>> {
// Create ONNX Runtime session // Create ONNX Runtime session
let mut session = create_session(&model_path)?; let mut session = create_session(&model_path)?;
let input_name = session.inputs()[0].name().to_string(); let input_info = infer_input_info(&session, img_width, img_height)?;
// Preprocess image // Preprocess
let input_tensor = preprocess_image(&img, model_info.input_size.0, model_info.input_size.1)?; let input_tensor = preprocess_image(&img, input_info.shape.width, input_info.shape.height)?;
// Run inference // Run inference
println!("Running inference..."); println!("Running inference...");
let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?; let outputs =
session.run(ort::inputs![input_info.name => Tensor::from_array(input_tensor)?])?;
// Extract mask output // Extract mask output
let mask_output = &outputs[0]; let mask_output = &outputs[0];

View file

@ -1,5 +1,8 @@
use ort::execution_providers::CUDAExecutionProvider; use ort::execution_providers::CUDAExecutionProvider;
use ort::session::{Session, builder::GraphOptimizationLevel}; use ort::{
session::{Session, builder::GraphOptimizationLevel},
value::ValueType,
};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use std::error::Error; use std::error::Error;
use std::fs; use std::fs;
@ -11,7 +14,6 @@ pub struct ModelInfo {
pub name: &'static str, pub name: &'static str,
pub url: &'static str, pub url: &'static str,
pub sha256: Option<&'static str>, pub sha256: Option<&'static str>,
pub input_size: (u32, u32),
} }
impl ModelInfo { impl ModelInfo {
@ -19,8 +21,7 @@ impl ModelInfo {
pub const BIREFNET_LITE: ModelInfo = ModelInfo { pub const BIREFNET_LITE: ModelInfo = ModelInfo {
name: "birefnet-general-lite", name: "birefnet-general-lite",
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx", url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
sha256: None, // We'll skip verification for now sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"),
input_size: (1024, 1024),
}; };
} }
@ -157,6 +158,74 @@ pub fn get_model_path(
Ok(model_path) Ok(model_path)
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InputShape {
pub batch: u32,
pub channels: u32,
pub height: u32,
pub width: u32,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InputInfo {
pub name: String,
pub shape: InputShape,
}
pub fn infer_input_info(
session: &Session,
image_width: u32,
image_height: u32,
) -> Result<InputInfo, Box<dyn Error>> {
let input = &session.inputs()[0];
let input_name = input.name().to_string();
let shape = match input.dtype() {
ValueType::Tensor { shape, .. } => shape,
_ => return Err("Expected tensor input".into()),
};
// Validate shape has 4 dimensions
if shape.len() != 4 {
return Err(format!("Expected 4D tensor, got {} dimensions", shape.len()).into());
}
// Process each dimension, replacing -1 with appropriate defaults
let batch = if shape[0] == -1 {
1
} else {
shape[0].try_into().map_err(|_| "Invalid batch dimension")?
};
let channels = if shape[1] == -1 {
3
} else {
shape[1]
.try_into()
.map_err(|_| "Invalid channels dimension")?
};
let height = if shape[2] == -1 {
image_height.min(4096)
} else {
shape[2]
.try_into()
.map_err(|_| "Invalid height dimension")?
};
let width = if shape[3] == -1 {
image_width.min(4096)
} else {
shape[3].try_into().map_err(|_| "Invalid width dimension")?
};
Ok(InputInfo {
name: input_name,
shape: InputShape {
batch,
channels,
height,
width,
},
})
}
/// Create an ONNX Runtime session from model path with CUDA backend /// Create an ONNX Runtime session from model path with CUDA backend
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> { pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
println!("Loading model into ONNX Runtime with CUDA backend..."); println!("Loading model into ONNX Runtime with CUDA backend...");