From 596cd28dcc081b83216f5302410ba5ab8f294138 Mon Sep 17 00:00:00 2001 From: Matthew Deville Date: Fri, 23 Jan 2026 13:26:42 +0100 Subject: [PATCH] typed models --- src/main.rs | 19 +++++-------------- src/model.rs | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/main.rs b/src/main.rs index 1ddd536..a816e59 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ mod model; mod postprocessing; mod preprocessing; -use model::{ModelInfo, create_session, get_model_path}; +use model::{Model, create_session, get_model_path}; use postprocessing::{apply_mask, create_side_by_side}; use preprocessing::preprocess_image; @@ -24,10 +24,11 @@ struct Args { short, long, group = "model_selection", - default_value = "bria", + value_enum, + default_value_t = Model::Bria, help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)" )] - model: String, + model: Model, #[arg(long, group = "model_selection", help = "Path to custom ONNX model")] model_path: Option, @@ -73,17 +74,7 @@ fn main() -> Result<(), Box> { model_path } else { // Use built-in model - let model_info = match args.model.as_str() { - "bria" => ModelInfo::BRIA, - "birefnet-lite" => ModelInfo::BIREFNET_LITE, - _ => { - return Err(format!( - "Unknown model: {}. Available models: 'bria', 'birefnet-lite'", - args.model - ) - .into()); - } - }; + let model_info = args.model.info(); println!("Using model: {}", model_info.name); get_model_path(&model_info, None, args.offline)? }; diff --git a/src/model.rs b/src/model.rs index 3b38490..ad24bff 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,3 +1,4 @@ +use clap::ValueEnum; use ort::execution_providers::CUDAExecutionProvider; use ort::session::{Session, builder::GraphOptimizationLevel}; use sha2::{Digest, Sha256}; @@ -6,6 +7,22 @@ use std::fs; use std::io::Write; use std::path::{Path, PathBuf}; +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +#[clap(rename_all = "kebab-case")] +pub enum Model { + BiRefNetLite, + Bria, +} + +impl Model { + pub fn info(&self) -> &ModelInfo { + match self { + Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE, + Model::Bria => &ModelInfo::BRIA, + } + } +} + /// Model metadata pub struct ModelInfo { pub name: &'static str,