wip
This commit is contained in:
parent
1fae113c5c
commit
f6961c3177
1 changed files with 5 additions and 3 deletions
|
|
@ -1,3 +1,4 @@
|
|||
use ort::execution_providers::CUDAExecutionProvider;
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::error::Error;
|
||||
|
|
@ -156,16 +157,17 @@ pub fn get_model_path(
|
|||
Ok(model_path)
|
||||
}
|
||||
|
||||
/// Create an ONNX Runtime session from model path
|
||||
/// Create an ONNX Runtime session from model path with CUDA backend
|
||||
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
||||
println!("Loading model into ONNX Runtime...");
|
||||
println!("Loading model into ONNX Runtime with CUDA backend...");
|
||||
|
||||
let session = Session::builder()?
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_intra_threads(4)?
|
||||
.commit_from_file(model_path)?;
|
||||
|
||||
println!("Model loaded successfully!");
|
||||
println!("Model loaded successfully with CUDA!");
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue