logging
This commit is contained in:
parent
8d2fad9106
commit
31d84c99d7
6 changed files with 235 additions and 57 deletions
112
Cargo.lock
generated
112
Cargo.lock
generated
|
|
@ -45,6 +45,15 @@ dependencies = [
|
|||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aligned"
|
||||
version = "0.4.3"
|
||||
|
|
@ -1776,6 +1785,15 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchers"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
|
||||
dependencies = [
|
||||
"regex-automata",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.10"
|
||||
|
|
@ -2021,6 +2039,15 @@ version = "0.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8"
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.50.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.6"
|
||||
|
|
@ -2702,6 +2729,23 @@ dependencies = [
|
|||
"bitflags 2.11.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.8.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
|
||||
|
||||
[[package]]
|
||||
name = "remove_background"
|
||||
version = "0.1.0"
|
||||
|
|
@ -2714,6 +2758,8 @@ dependencies = [
|
|||
"reqwest",
|
||||
"sha2",
|
||||
"show-image",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3029,6 +3075,15 @@ dependencies = [
|
|||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sharded-slab"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "1.3.0"
|
||||
|
|
@ -3303,6 +3358,15 @@ dependencies = [
|
|||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thread_local"
|
||||
version = "1.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiff"
|
||||
version = "0.11.3"
|
||||
|
|
@ -3473,9 +3537,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
|
||||
dependencies = [
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-attributes"
|
||||
version = "0.1.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-core"
|
||||
version = "0.1.36"
|
||||
|
|
@ -3483,6 +3559,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"valuable",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-log"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
|
||||
dependencies = [
|
||||
"log",
|
||||
"once_cell",
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-subscriber"
|
||||
version = "0.3.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319"
|
||||
dependencies = [
|
||||
"matchers",
|
||||
"nu-ansi-term",
|
||||
"once_cell",
|
||||
"regex-automata",
|
||||
"sharded-slab",
|
||||
"smallvec",
|
||||
"thread_local",
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
"tracing-log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3598,6 +3704,12 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "valuable"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
|
||||
|
||||
[[package]]
|
||||
name = "vcpkg"
|
||||
version = "0.2.15"
|
||||
|
|
|
|||
|
|
@ -12,3 +12,5 @@
|
|||
reqwest = { version = "0.13", features = ["blocking"] }
|
||||
sha2 = "0.11"
|
||||
show-image = { version = "0.14", features = ["image"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
|
||||
|
|
|
|||
84
src/main.rs
84
src/main.rs
|
|
@ -13,6 +13,8 @@ use {
|
|||
io::{Cursor, Read, Write, stdin, stdout},
|
||||
path::PathBuf,
|
||||
},
|
||||
tracing::{debug, info},
|
||||
tracing_subscriber::{EnvFilter, fmt},
|
||||
};
|
||||
|
||||
#[derive(Subcommand)]
|
||||
|
|
@ -67,8 +69,20 @@ struct Args {
|
|||
|
||||
#[show_image::main]
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
fmt()
|
||||
.with_env_filter(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
|
||||
.with_writer(std::io::stderr)
|
||||
.init();
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
info!(
|
||||
model = ?args.model,
|
||||
custom_model = args.model_path.as_deref().unwrap_or("<none>"),
|
||||
offline = args.offline,
|
||||
"Starting remove_background"
|
||||
);
|
||||
|
||||
let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?;
|
||||
|
||||
match args.command {
|
||||
|
|
@ -78,8 +92,10 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
output_file,
|
||||
} => {
|
||||
let img = if let Some(input_file) = input_file {
|
||||
debug!(path = %input_file.display(), "Reading input image from file");
|
||||
ImageReader::open(input_file)?.decode()?
|
||||
} else {
|
||||
debug!("Reading input image from stdin");
|
||||
let mut bytes = Vec::new();
|
||||
if stdin().lock().read_to_end(&mut bytes)? == 0 {
|
||||
return Err("No input data provided via stdin".into());
|
||||
|
|
@ -90,15 +106,9 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
};
|
||||
|
||||
let (img_width, img_height) = img.dimensions();
|
||||
println!("Loaded image: {}x{}\n", img_width, img_height);
|
||||
info!(width = img_width, height = img_height, "Loaded image");
|
||||
|
||||
println!("Running inference...");
|
||||
let mask = session.predict(&img)?;
|
||||
println!(
|
||||
"Inference complete! Mask dimensions: {}x{}",
|
||||
mask.dimensions().0,
|
||||
mask.dimensions().1
|
||||
);
|
||||
let result_rgba = apply_mask(&img, &mask)?;
|
||||
|
||||
if debug {
|
||||
|
|
@ -106,8 +116,10 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
}
|
||||
|
||||
if let Some(output_file) = output_file {
|
||||
result_rgba.save(output_file)?;
|
||||
result_rgba.save(&output_file)?;
|
||||
info!(path = %output_file.display(), "Wrote output image to file");
|
||||
} else {
|
||||
debug!("Writing output image to stdout (PNG)");
|
||||
let mut buffer = Cursor::new(Vec::new());
|
||||
result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?;
|
||||
let mut stdout = stdout().lock();
|
||||
|
|
@ -119,20 +131,49 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
output_directory,
|
||||
} => {
|
||||
fs::create_dir_all(&output_directory)?;
|
||||
let files = fs::read_dir(input_directory)?;
|
||||
for file in files {
|
||||
let file = file?;
|
||||
let path = file.path();
|
||||
let img = ImageReader::open(&path)?.decode()?;
|
||||
info!(
|
||||
input = %input_directory.display(),
|
||||
output = %output_directory.display(),
|
||||
"Starting batch processing"
|
||||
);
|
||||
|
||||
let mut processed: usize = 0;
|
||||
let mut failed: usize = 0;
|
||||
for entry in fs::read_dir(&input_directory)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
let span = tracing::info_span!("batch_item", path = %path.display());
|
||||
let _enter = span.enter();
|
||||
|
||||
let img = match ImageReader::open(&path) {
|
||||
Ok(reader) => match reader.decode() {
|
||||
Ok(img) => img,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to decode image, skipping");
|
||||
failed += 1;
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to open image, skipping");
|
||||
failed += 1;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mask = session.predict(&img)?;
|
||||
let result_rgba = apply_mask(&img, &mask)?;
|
||||
let output_path = output_directory.join(path.file_name().unwrap());
|
||||
result_rgba.save(&output_path)?;
|
||||
println!("Processed image saved to: {}", output_path.display());
|
||||
processed += 1;
|
||||
info!(output = %output_path.display(), "Processed image saved");
|
||||
}
|
||||
|
||||
info!(processed, failed, "Batch processing complete");
|
||||
}
|
||||
}
|
||||
|
||||
info!("Shutting down");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -140,7 +181,7 @@ fn debug_mode(
|
|||
img: &image::DynamicImage,
|
||||
result_rgba: &image::RgbaImage,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
println!("Creating side-by-side comparison...");
|
||||
info!("Creating side-by-side comparison");
|
||||
let composite = create_side_by_side(img, result_rgba)?;
|
||||
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
|
||||
|
||||
|
|
@ -156,21 +197,18 @@ fn debug_mode(
|
|||
.map_err(|e| e.to_string())?,
|
||||
)?;
|
||||
|
||||
println!("\n=== Done! ===");
|
||||
println!(
|
||||
"Displaying side-by-side comparison ({}x{}):",
|
||||
comp_width, comp_height
|
||||
info!(
|
||||
width = comp_width,
|
||||
height = comp_height,
|
||||
"Displaying side-by-side comparison (left: original, right: background removed on checkered bg); press ESC to close"
|
||||
);
|
||||
println!(" Left: Original image");
|
||||
println!(" Right: Background removed (shown on checkered background)");
|
||||
println!("\nPress ESC to close the window.");
|
||||
|
||||
for event in window.event_channel()? {
|
||||
if let event::WindowEvent::KeyboardInput(event) = event
|
||||
&& event.input.key_code == Some(event::VirtualKeyCode::Escape)
|
||||
&& event.input.state.is_pressed()
|
||||
{
|
||||
println!("ESC pressed, closing...");
|
||||
debug!("ESC pressed, closing window");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
55
src/model.rs
55
src/model.rs
|
|
@ -8,9 +8,11 @@ use {
|
|||
std::{
|
||||
error::Error,
|
||||
fs,
|
||||
io::{Read, Write},
|
||||
io::Read,
|
||||
path::{Path, PathBuf},
|
||||
time::Instant,
|
||||
},
|
||||
tracing::{debug, info, warn},
|
||||
};
|
||||
|
||||
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
|
||||
|
|
@ -34,7 +36,8 @@ fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
|
|||
}
|
||||
|
||||
fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
||||
println!("Downloading model from {}...", url);
|
||||
info!(%url, dest = %dest.display(), "Downloading model");
|
||||
let start = Instant::now();
|
||||
|
||||
let client = reqwest::blocking::Client::builder()
|
||||
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36")
|
||||
|
|
@ -46,9 +49,10 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
|||
return Err(format!("Failed to download model: HTTP {}", response.status()).into());
|
||||
}
|
||||
|
||||
let mut file = fs::File::create(dest)?;
|
||||
let mut file = std::io::BufWriter::new(fs::File::create(dest)?);
|
||||
let total_size = response.content_length().unwrap_or(0);
|
||||
let mut downloaded = 0u64;
|
||||
let mut last_reported_pct: u32 = 0;
|
||||
|
||||
let mut buffer = vec![0; 8192];
|
||||
loop {
|
||||
|
|
@ -56,21 +60,24 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
|||
if bytes_read == 0 {
|
||||
break;
|
||||
}
|
||||
file.write_all(&buffer[..bytes_read])?;
|
||||
std::io::Write::write_all(&mut file, &buffer[..bytes_read])?;
|
||||
downloaded += bytes_read as u64;
|
||||
|
||||
if total_size > 0 {
|
||||
let progress = (downloaded as f64 / total_size as f64 * 100.0) as u32;
|
||||
print!("\rDownloading... {}%", progress);
|
||||
std::io::stdout().flush()?;
|
||||
// Only emit every 10% to avoid log spam while still giving feedback.
|
||||
if progress >= last_reported_pct + 10 {
|
||||
last_reported_pct = progress - (progress % 10);
|
||||
debug!(percent = last_reported_pct, downloaded, total = total_size, "Download progress");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if total_size > 0 {
|
||||
println!("\rDownload complete! ");
|
||||
} else {
|
||||
println!("Download complete! ({} bytes)", downloaded);
|
||||
}
|
||||
info!(
|
||||
bytes = downloaded,
|
||||
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||
"Download complete"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -110,7 +117,7 @@ pub fn get_model_path(
|
|||
if !model_path.exists() {
|
||||
return Err(format!("Custom model path does not exist: {}", path).into());
|
||||
}
|
||||
println!("Using custom model: {}", path);
|
||||
info!(%path, "Using custom model");
|
||||
return Ok(model_path);
|
||||
}
|
||||
|
||||
|
|
@ -119,15 +126,14 @@ pub fn get_model_path(
|
|||
let model_path = cache_dir.join(&model_filename);
|
||||
|
||||
if model_path.exists() {
|
||||
println!("Using cached model: {}", model_path.display());
|
||||
info!(path = %model_path.display(), "Using cached model");
|
||||
|
||||
if let Some(expected_hash) = sha256 {
|
||||
print!("Verifying model integrity... ");
|
||||
std::io::stdout().flush()?;
|
||||
debug!("Verifying cached model integrity");
|
||||
if verify_hash(&model_path, expected_hash)? {
|
||||
println!("OK");
|
||||
debug!("Cached model hash OK");
|
||||
} else {
|
||||
println!("FAILED");
|
||||
warn!(path = %model_path.display(), "Cached model hash verification failed");
|
||||
return Err(
|
||||
"Model hash verification failed. Try deleting the cached model.".into(),
|
||||
);
|
||||
|
|
@ -145,14 +151,13 @@ pub fn get_model_path(
|
|||
.into());
|
||||
}
|
||||
|
||||
println!("Model not found in cache, downloading...");
|
||||
info!("Model not found in cache, downloading");
|
||||
download_file(url, &model_path)?;
|
||||
|
||||
if let Some(expected_hash) = sha256 {
|
||||
print!("Verifying downloaded model... ");
|
||||
std::io::stdout().flush()?;
|
||||
debug!("Verifying downloaded model");
|
||||
if verify_hash(&model_path, expected_hash)? {
|
||||
println!("OK");
|
||||
debug!("Downloaded model hash OK");
|
||||
} else {
|
||||
fs::remove_file(&model_path)?;
|
||||
return Err("Downloaded model hash verification failed".into());
|
||||
|
|
@ -164,7 +169,8 @@ pub fn get_model_path(
|
|||
|
||||
/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU).
|
||||
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
||||
println!("Loading model into ONNX Runtime with CUDA backend...");
|
||||
info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend");
|
||||
let start = Instant::now();
|
||||
|
||||
let session = Session::builder()?
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
|
|
@ -172,7 +178,10 @@ pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
|||
.with_intra_threads(4)?
|
||||
.commit_from_file(model_path)?;
|
||||
|
||||
println!("Model loaded successfully with CUDA!");
|
||||
info!(
|
||||
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||
"Model loaded successfully"
|
||||
);
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use {
|
||||
image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType},
|
||||
std::error::Error,
|
||||
tracing::debug,
|
||||
};
|
||||
|
||||
/// Compose `original` with `mask` as the alpha channel and return an RGBA image.
|
||||
|
|
@ -8,17 +9,18 @@ use {
|
|||
/// The mask is expected to already be grayscale. If its dimensions differ from
|
||||
/// the original, it is resized with LANCZOS3.
|
||||
pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage, Box<dyn Error>> {
|
||||
println!("Applying mask to remove background...");
|
||||
|
||||
let (orig_width, orig_height) = original.dimensions();
|
||||
let (mask_width, mask_height) = mask.dimensions();
|
||||
|
||||
println!("Mask dimensions: {}x{}", mask_width, mask_height);
|
||||
println!("Original dimensions: {}x{}", orig_width, orig_height);
|
||||
|
||||
let resized_mask: std::borrow::Cow<'_, GrayImage> =
|
||||
if mask_width != orig_width || mask_height != orig_height {
|
||||
println!("Resizing mask to match original image...");
|
||||
debug!(
|
||||
mask_width,
|
||||
mask_height,
|
||||
orig_width,
|
||||
orig_height,
|
||||
"Resizing mask to match original image"
|
||||
);
|
||||
std::borrow::Cow::Owned(image::imageops::resize(
|
||||
mask,
|
||||
orig_width,
|
||||
|
|
@ -45,7 +47,7 @@ pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage
|
|||
}
|
||||
}
|
||||
|
||||
println!("Background removal complete!");
|
||||
debug!("Mask applied to image");
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ use {
|
|||
image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType},
|
||||
ndarray::{Array4, IntoDimension},
|
||||
ort::{session::Session as OrtSession, value::Tensor},
|
||||
std::error::Error,
|
||||
std::{error::Error, time::Instant},
|
||||
tracing::{debug, info},
|
||||
};
|
||||
|
||||
use crate::model::Model;
|
||||
|
|
@ -105,13 +106,22 @@ pub trait Session {
|
|||
/// 5. scale to `u8`, resize to original image dimensions
|
||||
fn predict(&mut self, img: &DynamicImage) -> Result<GrayImage, Box<dyn Error>> {
|
||||
let (orig_w, orig_h) = img.dimensions();
|
||||
|
||||
let preprocess_start = Instant::now();
|
||||
let input = self.normalize(img)?;
|
||||
debug!(elapsed_ms = preprocess_start.elapsed().as_millis() as u64, "Preprocessing complete");
|
||||
let apply_sigmoid = self.apply_sigmoid();
|
||||
|
||||
let input_name = self.input_name().to_string();
|
||||
debug!(%input_name, "Running ONNX session");
|
||||
let infer_start = Instant::now();
|
||||
let outputs = self
|
||||
.inner()
|
||||
.run(ort::inputs![input_name => Tensor::from_array(input)?])?;
|
||||
info!(
|
||||
elapsed_ms = infer_start.elapsed().as_millis() as u64,
|
||||
"ONNX inference complete"
|
||||
);
|
||||
|
||||
let output = &outputs[0];
|
||||
let (shape, data) = output.try_extract_tensor::<f32>()?;
|
||||
|
|
@ -177,6 +187,11 @@ pub trait Session {
|
|||
}
|
||||
|
||||
let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3);
|
||||
debug!(
|
||||
mask_width = orig_w,
|
||||
mask_height = orig_h,
|
||||
"Mask resized to original image dimensions"
|
||||
);
|
||||
Ok(mask)
|
||||
}
|
||||
}
|
||||
|
|
@ -191,7 +206,7 @@ pub fn init_session(
|
|||
if !path.exists() {
|
||||
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
||||
}
|
||||
println!("Using custom model: {}", custom_path);
|
||||
info!(path = %custom_path, "Using custom model");
|
||||
match model {
|
||||
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
|
||||
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
|
||||
|
|
@ -199,11 +214,11 @@ pub fn init_session(
|
|||
} else {
|
||||
match model {
|
||||
Model::Bria => {
|
||||
println!("Using model: bria-rmbg");
|
||||
info!(model = "bria-rmbg", "Using model");
|
||||
Box::new(BriaSession::new(offline)?)
|
||||
}
|
||||
Model::BiRefNetLite => {
|
||||
println!("Using model: birefnet-general-lite");
|
||||
info!(model = "birefnet-general-lite", "Using model");
|
||||
Box::new(BiRefNetLiteSession::new(offline)?)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue