diff --git a/src/model.rs b/src/model.rs index 2f4d28d..dc17fff 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,3 +1,4 @@ +use ort::execution_providers::CUDAExecutionProvider; use ort::session::{Session, builder::GraphOptimizationLevel}; use sha2::{Digest, Sha256}; use std::error::Error; @@ -156,16 +157,17 @@ pub fn get_model_path( Ok(model_path) } -/// Create an ONNX Runtime session from model path +/// Create an ONNX Runtime session from model path with CUDA backend pub fn create_session(model_path: &Path) -> Result> { - println!("Loading model into ONNX Runtime..."); + println!("Loading model into ONNX Runtime with CUDA backend..."); let session = Session::builder()? + .with_execution_providers([CUDAExecutionProvider::default().build()])? .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(4)? .commit_from_file(model_path)?; - println!("Model loaded successfully!"); + println!("Model loaded successfully with CUDA!"); Ok(session) }