wip
This commit is contained in:
parent
9e9721cbbc
commit
22c9315b96
2 changed files with 12 additions and 79 deletions
12
src/main.rs
12
src/main.rs
|
|
@ -9,7 +9,7 @@ mod model;
|
||||||
mod postprocessing;
|
mod postprocessing;
|
||||||
mod preprocessing;
|
mod preprocessing;
|
||||||
|
|
||||||
use model::{ModelInfo, create_session, get_model_path, infer_input_info};
|
use model::{ModelInfo, create_session, get_model_path};
|
||||||
use postprocessing::{apply_mask, create_side_by_side};
|
use postprocessing::{apply_mask, create_side_by_side};
|
||||||
use preprocessing::preprocess_image;
|
use preprocessing::preprocess_image;
|
||||||
|
|
||||||
|
|
@ -54,7 +54,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
println!("Loaded image: {}x{}\n", img_width, img_height);
|
println!("Loaded image: {}x{}\n", img_width, img_height);
|
||||||
|
|
||||||
// Get model info
|
// Get model info
|
||||||
let model_info = ModelInfo::BIREFNET_LITE;
|
let model_info = ModelInfo::BRIA;
|
||||||
println!("Using model: {}", model_info.name);
|
println!("Using model: {}", model_info.name);
|
||||||
|
|
||||||
// Get or download model
|
// Get or download model
|
||||||
|
|
@ -62,16 +62,14 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
|
||||||
// Create ONNX Runtime session
|
// Create ONNX Runtime session
|
||||||
let mut session = create_session(&model_path)?;
|
let mut session = create_session(&model_path)?;
|
||||||
|
let input_name = session.inputs()[0].name().to_string();
|
||||||
let input_info = infer_input_info(&session, img_width, img_height)?;
|
|
||||||
|
|
||||||
// Preprocess
|
// Preprocess
|
||||||
let input_tensor = preprocess_image(&img, input_info.shape.width, input_info.shape.height)?;
|
let input_tensor = preprocess_image(&img, 1024, 1024)?;
|
||||||
|
|
||||||
// Run inference
|
// Run inference
|
||||||
println!("Running inference...");
|
println!("Running inference...");
|
||||||
let outputs =
|
let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?;
|
||||||
session.run(ort::inputs![input_info.name => Tensor::from_array(input_tensor)?])?;
|
|
||||||
|
|
||||||
// Extract mask output
|
// Extract mask output
|
||||||
let mask_output = &outputs[0];
|
let mask_output = &outputs[0];
|
||||||
|
|
|
||||||
79
src/model.rs
79
src/model.rs
|
|
@ -1,8 +1,5 @@
|
||||||
use ort::execution_providers::CUDAExecutionProvider;
|
use ort::execution_providers::CUDAExecutionProvider;
|
||||||
use ort::{
|
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||||
session::{Session, builder::GraphOptimizationLevel},
|
|
||||||
value::ValueType,
|
|
||||||
};
|
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
@ -23,6 +20,12 @@ impl ModelInfo {
|
||||||
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
|
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
|
||||||
sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"),
|
sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const BRIA: ModelInfo = ModelInfo {
|
||||||
|
name: "bria",
|
||||||
|
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx",
|
||||||
|
sha256: Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"),
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the cache directory for models
|
/// Get the cache directory for models
|
||||||
|
|
@ -158,74 +161,6 @@ pub fn get_model_path(
|
||||||
Ok(model_path)
|
Ok(model_path)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub struct InputShape {
|
|
||||||
pub batch: u32,
|
|
||||||
pub channels: u32,
|
|
||||||
pub height: u32,
|
|
||||||
pub width: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct InputInfo {
|
|
||||||
pub name: String,
|
|
||||||
pub shape: InputShape,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn infer_input_info(
|
|
||||||
session: &Session,
|
|
||||||
image_width: u32,
|
|
||||||
image_height: u32,
|
|
||||||
) -> Result<InputInfo, Box<dyn Error>> {
|
|
||||||
let input = &session.inputs()[0];
|
|
||||||
let input_name = input.name().to_string();
|
|
||||||
let shape = match input.dtype() {
|
|
||||||
ValueType::Tensor { shape, .. } => shape,
|
|
||||||
_ => return Err("Expected tensor input".into()),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Validate shape has 4 dimensions
|
|
||||||
if shape.len() != 4 {
|
|
||||||
return Err(format!("Expected 4D tensor, got {} dimensions", shape.len()).into());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process each dimension, replacing -1 with appropriate defaults
|
|
||||||
let batch = if shape[0] == -1 {
|
|
||||||
1
|
|
||||||
} else {
|
|
||||||
shape[0].try_into().map_err(|_| "Invalid batch dimension")?
|
|
||||||
};
|
|
||||||
let channels = if shape[1] == -1 {
|
|
||||||
3
|
|
||||||
} else {
|
|
||||||
shape[1]
|
|
||||||
.try_into()
|
|
||||||
.map_err(|_| "Invalid channels dimension")?
|
|
||||||
};
|
|
||||||
let height = if shape[2] == -1 {
|
|
||||||
image_height.min(4096)
|
|
||||||
} else {
|
|
||||||
shape[2]
|
|
||||||
.try_into()
|
|
||||||
.map_err(|_| "Invalid height dimension")?
|
|
||||||
};
|
|
||||||
let width = if shape[3] == -1 {
|
|
||||||
image_width.min(4096)
|
|
||||||
} else {
|
|
||||||
shape[3].try_into().map_err(|_| "Invalid width dimension")?
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(InputInfo {
|
|
||||||
name: input_name,
|
|
||||||
shape: InputShape {
|
|
||||||
batch,
|
|
||||||
channels,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an ONNX Runtime session from model path with CUDA backend
|
/// Create an ONNX Runtime session from model path with CUDA backend
|
||||||
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
||||||
println!("Loading model into ONNX Runtime with CUDA backend...");
|
println!("Loading model into ONNX Runtime with CUDA backend...");
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue