diff --git a/flake.nix b/flake.nix index a7257a5..6acc2d1 100644 --- a/flake.nix +++ b/flake.nix @@ -20,7 +20,9 @@ overlays = [ (import rust-overlay) ]; pkgs = import nixpkgs { inherit system overlays; }; stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv; - rustToolchain = pkgs.rust-bin.stable.latest.default; + rustToolchain = pkgs.rust-bin.stable.latest.default.override { + extensions = [ "rust-src" ]; + }; nativeBuildInputs = with pkgs; [ rustToolchain diff --git a/src/main.rs b/src/main.rs index 5fc4158..ef26557 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ mod model; mod postprocessing; mod preprocessing; -use model::{ModelInfo, create_session, get_model_path}; +use model::{ModelInfo, create_session, get_model_path, infer_input_info}; use postprocessing::{apply_mask, create_side_by_side}; use preprocessing::preprocess_image; @@ -50,8 +50,8 @@ fn main() -> Result<(), Box<dyn Error>> { // Load the image let img = image::load_from_memory(&bytes)?; - let (width, height) = img.dimensions(); - println!("Loaded image: {}x{}\n", width, height); + let (img_width, img_height) = img.dimensions(); + println!("Loaded image: {}x{}\n", img_width, img_height); // Get model info let model_info = ModelInfo::BIREFNET_LITE; @@ -63,14 +63,15 @@ fn main() -> Result<(), Box<dyn Error>> { // Create ONNX Runtime session let mut session = create_session(&model_path)?; - let input_name = session.inputs()[0].name().to_string(); + let input_info = infer_input_info(&session, img_width, img_height)?; - // Preprocess image - let input_tensor = preprocess_image(&img, model_info.input_size.0, model_info.input_size.1)?; + // Preprocess + let input_tensor = preprocess_image(&img, input_info.shape.width, input_info.shape.height)?; // Run inference println!("Running inference..."); - let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?; + let outputs = + session.run(ort::inputs![input_info.name => Tensor::from_array(input_tensor)?])?; // Extract mask output let 
mask_output = &outputs[0]; diff --git a/src/model.rs b/src/model.rs index dc17fff..ec8bad7 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,5 +1,8 @@ use ort::execution_providers::CUDAExecutionProvider; -use ort::session::{Session, builder::GraphOptimizationLevel}; +use ort::{ + session::{Session, builder::GraphOptimizationLevel}, + value::ValueType, +}; use sha2::{Digest, Sha256}; use std::error::Error; use std::fs; @@ -11,7 +14,6 @@ pub struct ModelInfo { pub name: &'static str, pub url: &'static str, pub sha256: Option<&'static str>, - pub input_size: (u32, u32), } impl ModelInfo { @@ -19,8 +21,7 @@ pub const BIREFNET_LITE: ModelInfo = ModelInfo { name: "birefnet-general-lite", url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx", - sha256: None, // We'll skip verification for now - input_size: (1024, 1024), + sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"), }; } @@ -157,6 +158,74 @@ pub fn get_model_path( Ok(model_path) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct InputShape { + pub batch: u32, + pub channels: u32, + pub height: u32, + pub width: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct InputInfo { + pub name: String, + pub shape: InputShape, +} + +pub fn infer_input_info( + session: &Session, + image_width: u32, + image_height: u32, +) -> Result<InputInfo, Box<dyn Error>> { + let input = &session.inputs()[0]; + let input_name = input.name().to_string(); + let shape = match input.dtype() { + ValueType::Tensor { shape, .. } => shape, + _ => return Err("Expected tensor input".into()), + }; + + // Validate shape has 4 dimensions + if shape.len() != 4 { + return Err(format!("Expected 4D tensor, got {} dimensions", shape.len()).into()); + } + + // Process each dimension, replacing -1 with appropriate defaults + let batch = if shape[0] == -1 { + 1 + } else { + shape[0].try_into().map_err(|_| "Invalid batch dimension")? 
+ }; + let channels = if shape[1] == -1 { + 3 + } else { + shape[1] + .try_into() + .map_err(|_| "Invalid channels dimension")? + }; + let height = if shape[2] == -1 { + image_height.min(4096) + } else { + shape[2] + .try_into() + .map_err(|_| "Invalid height dimension")? + }; + let width = if shape[3] == -1 { + image_width.min(4096) + } else { + shape[3].try_into().map_err(|_| "Invalid width dimension")? + }; + + Ok(InputInfo { + name: input_name, + shape: InputShape { + batch, + channels, + height, + width, + }, + }) +} + /// Create an ONNX Runtime session from model path with CUDA backend pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> { println!("Loading model into ONNX Runtime with CUDA backend...");