This commit is contained in:
Matthew Deville 2026-01-22 01:37:09 +01:00
parent 1fae113c5c
commit f6961c3177

View file

@ -1,3 +1,4 @@
use ort::execution_providers::CUDAExecutionProvider;
use ort::session::{Session, builder::GraphOptimizationLevel};
use sha2::{Digest, Sha256};
use std::error::Error;
@ -156,16 +157,17 @@ pub fn get_model_path(
Ok(model_path)
}
/// Create an ONNX Runtime session from model path
/// Create an ONNX Runtime session from model path with CUDA backend
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
println!("Loading model into ONNX Runtime...");
println!("Loading model into ONNX Runtime with CUDA backend...");
let session = Session::builder()?
.with_execution_providers([CUDAExecutionProvider::default().build()])?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)?
.commit_from_file(model_path)?;
println!("Model loaded successfully!");
println!("Model loaded successfully with CUDA!");
Ok(session)
}