typed models
This commit is contained in:
parent
6e965c4a18
commit
596cd28dcc
2 changed files with 22 additions and 14 deletions
19
src/main.rs
19
src/main.rs
|
|
@ -9,7 +9,7 @@ mod model;
|
|||
mod postprocessing;
|
||||
mod preprocessing;
|
||||
|
||||
use model::{ModelInfo, create_session, get_model_path};
|
||||
use model::{Model, create_session, get_model_path};
|
||||
use postprocessing::{apply_mask, create_side_by_side};
|
||||
use preprocessing::preprocess_image;
|
||||
|
||||
|
|
@ -24,10 +24,11 @@ struct Args {
|
|||
short,
|
||||
long,
|
||||
group = "model_selection",
|
||||
default_value = "bria",
|
||||
value_enum,
|
||||
default_value_t = Model::Bria,
|
||||
help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)"
|
||||
)]
|
||||
model: String,
|
||||
model: Model,
|
||||
|
||||
#[arg(long, group = "model_selection", help = "Path to custom ONNX model")]
|
||||
model_path: Option<String>,
|
||||
|
|
@ -73,17 +74,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
model_path
|
||||
} else {
|
||||
// Use built-in model
|
||||
let model_info = match args.model.as_str() {
|
||||
"bria" => ModelInfo::BRIA,
|
||||
"birefnet-lite" => ModelInfo::BIREFNET_LITE,
|
||||
_ => {
|
||||
return Err(format!(
|
||||
"Unknown model: {}. Available models: 'bria', 'birefnet-lite'",
|
||||
args.model
|
||||
)
|
||||
.into());
|
||||
}
|
||||
};
|
||||
let model_info = args.model.info();
|
||||
println!("Using model: {}", model_info.name);
|
||||
get_model_path(&model_info, None, args.offline)?
|
||||
};
|
||||
|
|
|
|||
17
src/model.rs
17
src/model.rs
|
|
@ -1,3 +1,4 @@
|
|||
use clap::ValueEnum;
|
||||
use ort::execution_providers::CUDAExecutionProvider;
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
|
@ -6,6 +7,22 @@ use std::fs;
|
|||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
|
||||
#[clap(rename_all = "kebab-case")]
|
||||
pub enum Model {
|
||||
BiRefNetLite,
|
||||
Bria,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn info(&self) -> &ModelInfo {
|
||||
match self {
|
||||
Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE,
|
||||
Model::Bria => &ModelInfo::BRIA,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Model metadata
|
||||
pub struct ModelInfo {
|
||||
pub name: &'static str,
|
||||
|
|
|
|||
Loading…
Reference in a new issue