wip
This commit is contained in:
parent
1fae113c5c
commit
f6961c3177
1 changed files with 5 additions and 3 deletions
|
|
@ -1,3 +1,4 @@
|
||||||
|
use ort::execution_providers::CUDAExecutionProvider;
|
||||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
|
@ -156,16 +157,17 @@ pub fn get_model_path(
|
||||||
Ok(model_path)
|
Ok(model_path)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an ONNX Runtime session from model path
|
/// Create an ONNX Runtime session from model path with CUDA backend
|
||||||
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
||||||
println!("Loading model into ONNX Runtime...");
|
println!("Loading model into ONNX Runtime with CUDA backend...");
|
||||||
|
|
||||||
let session = Session::builder()?
|
let session = Session::builder()?
|
||||||
|
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||||
.with_intra_threads(4)?
|
.with_intra_threads(4)?
|
||||||
.commit_from_file(model_path)?;
|
.commit_from_file(model_path)?;
|
||||||
|
|
||||||
println!("Model loaded successfully!");
|
println!("Model loaded successfully with CUDA!");
|
||||||
|
|
||||||
Ok(session)
|
Ok(session)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue