This commit is contained in:
Matthew Deville 2026-01-22 01:37:09 +01:00
parent 1fae113c5c
commit f6961c3177

View file

@ -1,3 +1,4 @@
use ort::execution_providers::CUDAExecutionProvider;
use ort::session::{Session, builder::GraphOptimizationLevel}; use ort::session::{Session, builder::GraphOptimizationLevel};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use std::error::Error; use std::error::Error;
@ -156,16 +157,17 @@ pub fn get_model_path(
Ok(model_path) Ok(model_path)
} }
/// Create an ONNX Runtime session from model path /// Create an ONNX Runtime session from model path with CUDA backend
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> { pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
println!("Loading model into ONNX Runtime..."); println!("Loading model into ONNX Runtime with CUDA backend...");
let session = Session::builder()? let session = Session::builder()?
.with_execution_providers([CUDAExecutionProvider::default().build()])?
.with_optimization_level(GraphOptimizationLevel::Level3)? .with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)? .with_intra_threads(4)?
.commit_from_file(model_path)?; .commit_from_file(model_path)?;
println!("Model loaded successfully!"); println!("Model loaded successfully with CUDA!");
Ok(session) Ok(session)
} }