From 609afde7278646c437cefb7d3dc6e31cce93ce11 Mon Sep 17 00:00:00 2001 From: Matthew Deville Date: Thu, 22 Jan 2026 23:36:35 +0100 Subject: [PATCH] wip --- src/main.rs | 41 ++++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/src/main.rs b/src/main.rs index cf64b59..3d18854 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,11 +20,19 @@ struct Args { #[arg(help = "URL of the image to download and process")] url: String, - #[arg(long, help = "Path to custom ONNX model")] + #[arg(long, group = "model_selection", help = "Path to custom ONNX model")] model_path: Option, #[arg(long, help = "Skip model download, fail if not cached")] offline: bool, + + #[arg( + long, + group = "model_selection", + default_value = "bria", + help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)" + )] + model: String, } #[show_image::main] @@ -53,12 +61,31 @@ fn main() -> Result<(), Box> { let (img_width, img_height) = img.dimensions(); println!("Loaded image: {}x{}\n", img_width, img_height); - // Get model info - let model_info = ModelInfo::BRIA; - println!("Using model: {}", model_info.name); - - // Get or download model - let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?; + // Get model path - either from custom path or by selecting built-in model + let model_path = if let Some(custom_path) = args.model_path { + // Custom model path provided, use it directly + let model_path = std::path::PathBuf::from(&custom_path); + if !model_path.exists() { + return Err(format!("Custom model path does not exist: {}", custom_path).into()); + } + println!("Using custom model: {}", custom_path); + model_path + } else { + // Use built-in model + let model_info = match args.model.as_str() { + "bria" => ModelInfo::BRIA, + "birefnet-lite" => ModelInfo::BIREFNET_LITE, + _ => { + return Err(format!( + "Unknown model: {}. Available models: 'bria', 'birefnet-lite'", + args.model + ) + .into()); + } + }; + println!("Using model: {}", model_info.name); + get_model_path(&model_info, None, args.offline)? + }; // Create ONNX Runtime session let mut session = create_session(&model_path)?;