typed models
This commit is contained in:
parent
6e965c4a18
commit
596cd28dcc
2 changed files with 22 additions and 14 deletions
19
src/main.rs
19
src/main.rs
|
|
@ -9,7 +9,7 @@ mod model;
|
||||||
mod postprocessing;
|
mod postprocessing;
|
||||||
mod preprocessing;
|
mod preprocessing;
|
||||||
|
|
||||||
use model::{ModelInfo, create_session, get_model_path};
|
use model::{Model, create_session, get_model_path};
|
||||||
use postprocessing::{apply_mask, create_side_by_side};
|
use postprocessing::{apply_mask, create_side_by_side};
|
||||||
use preprocessing::preprocess_image;
|
use preprocessing::preprocess_image;
|
||||||
|
|
||||||
|
|
@ -24,10 +24,11 @@ struct Args {
|
||||||
short,
|
short,
|
||||||
long,
|
long,
|
||||||
group = "model_selection",
|
group = "model_selection",
|
||||||
default_value = "bria",
|
value_enum,
|
||||||
|
default_value_t = Model::Bria,
|
||||||
help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)"
|
help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)"
|
||||||
)]
|
)]
|
||||||
model: String,
|
model: Model,
|
||||||
|
|
||||||
#[arg(long, group = "model_selection", help = "Path to custom ONNX model")]
|
#[arg(long, group = "model_selection", help = "Path to custom ONNX model")]
|
||||||
model_path: Option<String>,
|
model_path: Option<String>,
|
||||||
|
|
@ -73,17 +74,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
model_path
|
model_path
|
||||||
} else {
|
} else {
|
||||||
// Use built-in model
|
// Use built-in model
|
||||||
let model_info = match args.model.as_str() {
|
let model_info = args.model.info();
|
||||||
"bria" => ModelInfo::BRIA,
|
|
||||||
"birefnet-lite" => ModelInfo::BIREFNET_LITE,
|
|
||||||
_ => {
|
|
||||||
return Err(format!(
|
|
||||||
"Unknown model: {}. Available models: 'bria', 'birefnet-lite'",
|
|
||||||
args.model
|
|
||||||
)
|
|
||||||
.into());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
println!("Using model: {}", model_info.name);
|
println!("Using model: {}", model_info.name);
|
||||||
get_model_path(&model_info, None, args.offline)?
|
get_model_path(&model_info, None, args.offline)?
|
||||||
};
|
};
|
||||||
|
|
|
||||||
17
src/model.rs
17
src/model.rs
|
|
@ -1,3 +1,4 @@
|
||||||
|
use clap::ValueEnum;
|
||||||
use ort::execution_providers::CUDAExecutionProvider;
|
use ort::execution_providers::CUDAExecutionProvider;
|
||||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
|
|
@ -6,6 +7,22 @@ use std::fs;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
#[clap(rename_all = "kebab-case")]
|
||||||
|
pub enum Model {
|
||||||
|
BiRefNetLite,
|
||||||
|
Bria,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn info(&self) -> &ModelInfo {
|
||||||
|
match self {
|
||||||
|
Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE,
|
||||||
|
Model::Bria => &ModelInfo::BRIA,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Model metadata
|
/// Model metadata
|
||||||
pub struct ModelInfo {
|
pub struct ModelInfo {
|
||||||
pub name: &'static str,
|
pub name: &'static str,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue