From 25a4a204b9af8d04a323ce59c44a5583c1ad9e9b Mon Sep 17 00:00:00 2001 From: Matthew Deville Date: Thu, 23 Apr 2026 18:13:40 +0200 Subject: [PATCH] make cuda a feature --- Cargo.toml | 6 +++++- src/model.rs | 24 ++++++++++++++++-------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 410879c..083e332 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,13 +3,17 @@ name = "remove_background" version = "0.1.0" +[features] + default = [] + cuda = ["ort/cuda"] + [dependencies] anyhow = "1" clap = { version = "4", features = ["derive"] } hex = "0.4" image = "0.25" ndarray = "0.17" - ort = { version = "=2.0.0-rc.12", features = ["cuda"] } + ort = { version = "=2.0.0-rc.12" } reqwest = { version = "0.13", features = ["blocking"] } sha2 = "0.11" show-image = { version = "0.14", features = ["image"] } diff --git a/src/model.rs b/src/model.rs index 4778fe5..f577bac 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,9 +1,6 @@ use { clap::ValueEnum, - ort::{ - execution_providers::CUDAExecutionProvider, - session::{Session, builder::GraphOptimizationLevel}, - }, + ort::session::Session, sha2::{Digest, Sha256}, std::{ fs, @@ -16,6 +13,9 @@ use { use crate::error::{AppError, AppResult}; +#[cfg(feature = "cuda")] +use ort::{builder::GraphOptimizationLevel, ep::CUDAExecutionProvider}; + /// CLI-facing model selector. Concrete session metadata (URL, checksum, /// preprocessing params) lives on the `Session` trait impls in `sessions/`. #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] @@ -173,12 +173,19 @@ pub fn get_model_path( Ok(model_path) } -/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU). +/// Create an ONNX Runtime session from a model path. +/// +/// Uses the CUDA execution provider when built with `--features cuda`; otherwise runs on CPU. pub fn create_session(model_path: &Path) -> AppResult { + #[cfg(feature = "cuda")] info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend"); + #[cfg(not(feature = "cuda"))] + info!(path = %model_path.display(), "Loading model into ONNX Runtime with CPU backend"); let start = Instant::now(); - let session = Session::builder()? + let mut builder = Session::builder()?; + #[cfg(feature = "cuda")] + let builder = builder .with_execution_providers([CUDAExecutionProvider::default().build()]) .map_err(|err| AppError::OrtSessionBuilder { message: err.to_string(), @@ -190,8 +197,9 @@ pub fn create_session(model_path: &Path) -> AppResult { .with_intra_threads(4) .map_err(|err| AppError::OrtSessionBuilder { message: err.to_string(), - })? - .commit_from_file(model_path)?; + })?; + + let session = builder.commit_from_file(model_path)?; info!( elapsed_ms = start.elapsed().as_millis() as u64,