typed models

This commit is contained in:
Matthew Deville 2026-01-23 13:26:42 +01:00
parent 6e965c4a18
commit 596cd28dcc
2 changed files with 22 additions and 14 deletions

View file

@ -9,7 +9,7 @@ mod model;
mod postprocessing;
mod preprocessing;
use model::{ModelInfo, create_session, get_model_path};
use model::{Model, create_session, get_model_path};
use postprocessing::{apply_mask, create_side_by_side};
use preprocessing::preprocess_image;
@ -24,10 +24,11 @@ struct Args {
short,
long,
group = "model_selection",
default_value = "bria",
value_enum,
default_value_t = Model::Bria,
help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)"
)]
model: String,
model: Model,
#[arg(long, group = "model_selection", help = "Path to custom ONNX model")]
model_path: Option<String>,
@ -73,17 +74,7 @@ fn main() -> Result<(), Box<dyn Error>> {
model_path
} else {
// Use built-in model
let model_info = match args.model.as_str() {
"bria" => ModelInfo::BRIA,
"birefnet-lite" => ModelInfo::BIREFNET_LITE,
_ => {
return Err(format!(
"Unknown model: {}. Available models: 'bria', 'birefnet-lite'",
args.model
)
.into());
}
};
let model_info = args.model.info();
println!("Using model: {}", model_info.name);
get_model_path(&model_info, None, args.offline)?
};

View file

@ -1,3 +1,4 @@
use clap::ValueEnum;
use ort::execution_providers::CUDAExecutionProvider;
use ort::session::{Session, builder::GraphOptimizationLevel};
use sha2::{Digest, Sha256};
@ -6,6 +7,22 @@ use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
#[clap(rename_all = "kebab-case")]
pub enum Model {
BiRefNetLite,
Bria,
}
impl Model {
pub fn info(&self) -> &ModelInfo {
match self {
Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE,
Model::Bria => &ModelInfo::BRIA,
}
}
}
/// Model metadata
pub struct ModelInfo {
pub name: &'static str,