diff --git a/src/main.rs b/src/main.rs index ef26557..cf64b59 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ mod model; mod postprocessing; mod preprocessing; -use model::{ModelInfo, create_session, get_model_path, infer_input_info}; +use model::{ModelInfo, create_session, get_model_path}; use postprocessing::{apply_mask, create_side_by_side}; use preprocessing::preprocess_image; @@ -54,7 +54,7 @@ fn main() -> Result<(), Box> { println!("Loaded image: {}x{}\n", img_width, img_height); // Get model info - let model_info = ModelInfo::BIREFNET_LITE; + let model_info = ModelInfo::BRIA; println!("Using model: {}", model_info.name); // Get or download model @@ -62,16 +62,14 @@ fn main() -> Result<(), Box> { // Create ONNX Runtime session let mut session = create_session(&model_path)?; - - let input_info = infer_input_info(&session, img_width, img_height)?; + let input_name = session.inputs()[0].name().to_string(); // Preprocess - let input_tensor = preprocess_image(&img, input_info.shape.width, input_info.shape.height)?; + let input_tensor = preprocess_image(&img, 1024, 1024)?; // Run inference println!("Running inference..."); - let outputs = - session.run(ort::inputs![input_info.name => Tensor::from_array(input_tensor)?])?; + let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?; // Extract mask output let mask_output = &outputs[0]; diff --git a/src/model.rs b/src/model.rs index ec8bad7..3b38490 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,8 +1,5 @@ use ort::execution_providers::CUDAExecutionProvider; -use ort::{ - session::{Session, builder::GraphOptimizationLevel}, - value::ValueType, -}; +use ort::session::{Session, builder::GraphOptimizationLevel}; use sha2::{Digest, Sha256}; use std::error::Error; use std::fs; @@ -23,6 +20,12 @@ impl ModelInfo { url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx", sha256: 
Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"), }; + + pub const BRIA: ModelInfo = ModelInfo { + name: "bria", + url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx", + sha256: Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"), + }; } /// Get the cache directory for models @@ -158,74 +161,6 @@ pub fn get_model_path( Ok(model_path) } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct InputShape { - pub batch: u32, - pub channels: u32, - pub height: u32, - pub width: u32, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct InputInfo { - pub name: String, - pub shape: InputShape, -} - -pub fn infer_input_info( - session: &Session, - image_width: u32, - image_height: u32, -) -> Result<InputInfo, Box<dyn Error>> { - let input = &session.inputs()[0]; - let input_name = input.name().to_string(); - let shape = match input.dtype() { - ValueType::Tensor { shape, .. } => shape, - _ => return Err("Expected tensor input".into()), - }; - - // Validate shape has 4 dimensions - if shape.len() != 4 { - return Err(format!("Expected 4D tensor, got {} dimensions", shape.len()).into()); - } - - // Process each dimension, replacing -1 with appropriate defaults - let batch = if shape[0] == -1 { - 1 - } else { - shape[0].try_into().map_err(|_| "Invalid batch dimension")? - }; - let channels = if shape[1] == -1 { - 3 - } else { - shape[1] - .try_into() - .map_err(|_| "Invalid channels dimension")? - }; - let height = if shape[2] == -1 { - image_height.min(4096) - } else { - shape[2] - .try_into() - .map_err(|_| "Invalid height dimension")? - }; - let width = if shape[3] == -1 { - image_width.min(4096) - } else { - shape[3].try_into().map_err(|_| "Invalid width dimension")? 
- }; - - Ok(InputInfo { - name: input_name, - shape: InputShape { - batch, - channels, - height, - width, - }, - }) -} - /// Create an ONNX Runtime session from model path with CUDA backend pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> { println!("Loading model into ONNX Runtime with CUDA backend...");