wip
This commit is contained in:
parent
22c9315b96
commit
609afde727
1 changed files with 34 additions and 7 deletions
41
src/main.rs
41
src/main.rs
|
|
@ -20,11 +20,19 @@ struct Args {
|
||||||
#[arg(help = "URL of the image to download and process")]
|
#[arg(help = "URL of the image to download and process")]
|
||||||
url: String,
|
url: String,
|
||||||
|
|
||||||
#[arg(long, help = "Path to custom ONNX model")]
|
#[arg(long, group = "model_selection", help = "Path to custom ONNX model")]
|
||||||
model_path: Option<String>,
|
model_path: Option<String>,
|
||||||
|
|
||||||
#[arg(long, help = "Skip model download, fail if not cached")]
|
#[arg(long, help = "Skip model download, fail if not cached")]
|
||||||
offline: bool,
|
offline: bool,
|
||||||
|
|
||||||
|
#[arg(
|
||||||
|
long,
|
||||||
|
group = "model_selection",
|
||||||
|
default_value = "bria",
|
||||||
|
help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)"
|
||||||
|
)]
|
||||||
|
model: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[show_image::main]
|
#[show_image::main]
|
||||||
|
|
@ -53,12 +61,31 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
let (img_width, img_height) = img.dimensions();
|
let (img_width, img_height) = img.dimensions();
|
||||||
println!("Loaded image: {}x{}\n", img_width, img_height);
|
println!("Loaded image: {}x{}\n", img_width, img_height);
|
||||||
|
|
||||||
// Get model info
|
// Get model path - either from custom path or by selecting built-in model
|
||||||
let model_info = ModelInfo::BRIA;
|
let model_path = if let Some(custom_path) = args.model_path {
|
||||||
println!("Using model: {}", model_info.name);
|
// Custom model path provided, use it directly
|
||||||
|
let model_path = std::path::PathBuf::from(&custom_path);
|
||||||
// Get or download model
|
if !model_path.exists() {
|
||||||
let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?;
|
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
||||||
|
}
|
||||||
|
println!("Using custom model: {}", custom_path);
|
||||||
|
model_path
|
||||||
|
} else {
|
||||||
|
// Use built-in model
|
||||||
|
let model_info = match args.model.as_str() {
|
||||||
|
"bria" => ModelInfo::BRIA,
|
||||||
|
"birefnet-lite" => ModelInfo::BIREFNET_LITE,
|
||||||
|
_ => {
|
||||||
|
return Err(format!(
|
||||||
|
"Unknown model: {}. Available models: 'bria', 'birefnet-lite'",
|
||||||
|
args.model
|
||||||
|
)
|
||||||
|
.into());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
println!("Using model: {}", model_info.name);
|
||||||
|
get_model_path(&model_info, None, args.offline)?
|
||||||
|
};
|
||||||
|
|
||||||
// Create ONNX Runtime session
|
// Create ONNX Runtime session
|
||||||
let mut session = create_session(&model_path)?;
|
let mut session = create_session(&model_path)?;
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue