This commit is contained in:
Matthew Deville 2026-01-22 23:36:35 +01:00
parent 22c9315b96
commit 609afde727

View file

@ -20,11 +20,19 @@ struct Args {
#[arg(help = "URL of the image to download and process")] #[arg(help = "URL of the image to download and process")]
url: String, url: String,
#[arg(long, help = "Path to custom ONNX model")] #[arg(long, group = "model_selection", help = "Path to custom ONNX model")]
model_path: Option<String>, model_path: Option<String>,
#[arg(long, help = "Skip model download, fail if not cached")] #[arg(long, help = "Skip model download, fail if not cached")]
offline: bool, offline: bool,
#[arg(
long,
group = "model_selection",
default_value = "bria",
help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)"
)]
model: String,
} }
#[show_image::main] #[show_image::main]
@ -53,12 +61,31 @@ fn main() -> Result<(), Box<dyn Error>> {
let (img_width, img_height) = img.dimensions(); let (img_width, img_height) = img.dimensions();
println!("Loaded image: {}x{}\n", img_width, img_height); println!("Loaded image: {}x{}\n", img_width, img_height);
// Get model info // Get model path - either from custom path or by selecting built-in model
let model_info = ModelInfo::BRIA; let model_path = if let Some(custom_path) = args.model_path {
// Custom model path provided, use it directly
let model_path = std::path::PathBuf::from(&custom_path);
if !model_path.exists() {
return Err(format!("Custom model path does not exist: {}", custom_path).into());
}
println!("Using custom model: {}", custom_path);
model_path
} else {
// Use built-in model
let model_info = match args.model.as_str() {
"bria" => ModelInfo::BRIA,
"birefnet-lite" => ModelInfo::BIREFNET_LITE,
_ => {
return Err(format!(
"Unknown model: {}. Available models: 'bria', 'birefnet-lite'",
args.model
)
.into());
}
};
println!("Using model: {}", model_info.name); println!("Using model: {}", model_info.name);
get_model_path(&model_info, None, args.offline)?
// Get or download model };
let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?;
// Create ONNX Runtime session // Create ONNX Runtime session
let mut session = create_session(&model_path)?; let mut session = create_session(&model_path)?;