Make CUDA an optional Cargo feature (disabled by default)
This commit is contained in:
parent
5660af24f7
commit
25a4a204b9
2 changed files with 21 additions and 9 deletions
|
|
@ -3,13 +3,17 @@
|
|||
name = "remove_background"
|
||||
version = "0.1.0"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cuda = ["ort/cuda"]
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
hex = "0.4"
|
||||
image = "0.25"
|
||||
ndarray = "0.17"
|
||||
ort = { version = "=2.0.0-rc.12", features = ["cuda"] }
|
||||
ort = { version = "=2.0.0-rc.12" }
|
||||
reqwest = { version = "0.13", features = ["blocking"] }
|
||||
sha2 = "0.11"
|
||||
show-image = { version = "0.14", features = ["image"] }
|
||||
|
|
|
|||
24
src/model.rs
24
src/model.rs
|
|
@ -1,9 +1,6 @@
|
|||
use {
|
||||
clap::ValueEnum,
|
||||
ort::{
|
||||
execution_providers::CUDAExecutionProvider,
|
||||
session::{Session, builder::GraphOptimizationLevel},
|
||||
},
|
||||
ort::session::Session,
|
||||
sha2::{Digest, Sha256},
|
||||
std::{
|
||||
fs,
|
||||
|
|
@ -16,6 +13,9 @@ use {
|
|||
|
||||
use crate::error::{AppError, AppResult};
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use ort::{builder::GraphOptimizationLevel, ep::CUDAExecutionProvider};
|
||||
|
||||
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
|
||||
/// preprocessing params) lives on the `Session` trait impls in `sessions/`.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
|
||||
|
|
@ -173,12 +173,19 @@ pub fn get_model_path(
|
|||
Ok(model_path)
|
||||
}
|
||||
|
||||
/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU).
|
||||
/// Create an ONNX Runtime session from a model path.
|
||||
///
|
||||
/// Uses the CUDA execution provider when built with `--features cuda`; otherwise runs on CPU.
|
||||
pub fn create_session(model_path: &Path) -> AppResult<Session> {
|
||||
#[cfg(feature = "cuda")]
|
||||
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend");
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CPU backend");
|
||||
let start = Instant::now();
|
||||
|
||||
let session = Session::builder()?
|
||||
let mut builder = Session::builder()?;
|
||||
#[cfg(feature = "cuda")]
|
||||
let builder = builder
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])
|
||||
.map_err(|err| AppError::OrtSessionBuilder {
|
||||
message: err.to_string(),
|
||||
|
|
@ -190,8 +197,9 @@ pub fn create_session(model_path: &Path) -> AppResult<Session> {
|
|||
.with_intra_threads(4)
|
||||
.map_err(|err| AppError::OrtSessionBuilder {
|
||||
message: err.to_string(),
|
||||
})?
|
||||
.commit_from_file(model_path)?;
|
||||
})?;
|
||||
|
||||
let session = builder.commit_from_file(model_path)?;
|
||||
|
||||
info!(
|
||||
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||
|
|
|
|||
Loading…
Reference in a new issue