make cuda a feature

This commit is contained in:
Matthew Deville 2026-04-23 18:13:40 +02:00
parent 5660af24f7
commit 25a4a204b9
2 changed files with 21 additions and 9 deletions

View file

@ -3,13 +3,17 @@
name = "remove_background" name = "remove_background"
version = "0.1.0" version = "0.1.0"
[features]
default = []
cuda = ["ort/cuda"]
[dependencies] [dependencies]
anyhow = "1" anyhow = "1"
clap = { version = "4", features = ["derive"] } clap = { version = "4", features = ["derive"] }
hex = "0.4" hex = "0.4"
image = "0.25" image = "0.25"
ndarray = "0.17" ndarray = "0.17"
ort = { version = "=2.0.0-rc.12", features = ["cuda"] } ort = { version = "=2.0.0-rc.12" }
reqwest = { version = "0.13", features = ["blocking"] } reqwest = { version = "0.13", features = ["blocking"] }
sha2 = "0.11" sha2 = "0.11"
show-image = { version = "0.14", features = ["image"] } show-image = { version = "0.14", features = ["image"] }

View file

@ -1,9 +1,6 @@
use { use {
clap::ValueEnum, clap::ValueEnum,
ort::{ ort::session::Session,
execution_providers::CUDAExecutionProvider,
session::{Session, builder::GraphOptimizationLevel},
},
sha2::{Digest, Sha256}, sha2::{Digest, Sha256},
std::{ std::{
fs, fs,
@ -16,6 +13,9 @@ use {
use crate::error::{AppError, AppResult}; use crate::error::{AppError, AppResult};
#[cfg(feature = "cuda")]
use ort::{builder::GraphOptimizationLevel, ep::CUDAExecutionProvider};
/// CLI-facing model selector. Concrete session metadata (URL, checksum, /// CLI-facing model selector. Concrete session metadata (URL, checksum,
/// preprocessing params) lives on the `Session` trait impls in `sessions/`. /// preprocessing params) lives on the `Session` trait impls in `sessions/`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
@ -173,12 +173,19 @@ pub fn get_model_path(
Ok(model_path) Ok(model_path)
} }
/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU). /// Create an ONNX Runtime session from a model path.
///
/// Uses the CUDA execution provider when built with `--features cuda`; otherwise runs on CPU.
pub fn create_session(model_path: &Path) -> AppResult<Session> { pub fn create_session(model_path: &Path) -> AppResult<Session> {
#[cfg(feature = "cuda")]
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend"); info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend");
#[cfg(not(feature = "cuda"))]
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CPU backend");
let start = Instant::now(); let start = Instant::now();
let session = Session::builder()? let mut builder = Session::builder()?;
#[cfg(feature = "cuda")]
let builder = builder
.with_execution_providers([CUDAExecutionProvider::default().build()]) .with_execution_providers([CUDAExecutionProvider::default().build()])
.map_err(|err| AppError::OrtSessionBuilder { .map_err(|err| AppError::OrtSessionBuilder {
message: err.to_string(), message: err.to_string(),
@ -190,8 +197,9 @@ pub fn create_session(model_path: &Path) -> AppResult<Session> {
.with_intra_threads(4) .with_intra_threads(4)
.map_err(|err| AppError::OrtSessionBuilder { .map_err(|err| AppError::OrtSessionBuilder {
message: err.to_string(), message: err.to_string(),
})? })?;
.commit_from_file(model_path)?;
let session = builder.commit_from_file(model_path)?;
info!( info!(
elapsed_ms = start.elapsed().as_millis() as u64, elapsed_ms = start.elapsed().as_millis() as u64,