This commit is contained in:
Matthew Deville 2026-04-22 22:30:41 +02:00
parent 8d2fad9106
commit 31d84c99d7
6 changed files with 235 additions and 57 deletions

112
Cargo.lock generated
View file

@ -45,6 +45,15 @@ dependencies = [
"zerocopy", "zerocopy",
] ]
[[package]]
name = "aho-corasick"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "aligned" name = "aligned"
version = "0.4.3" version = "0.4.3"
@ -1776,6 +1785,15 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "matchers"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
dependencies = [
"regex-automata",
]
[[package]] [[package]]
name = "matrixmultiply" name = "matrixmultiply"
version = "0.3.10" version = "0.3.10"
@ -2021,6 +2039,15 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8"
[[package]]
name = "nu-ansi-term"
version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "num-bigint" name = "num-bigint"
version = "0.4.6" version = "0.4.6"
@ -2702,6 +2729,23 @@ dependencies = [
"bitflags 2.11.1", "bitflags 2.11.1",
] ]
[[package]]
name = "regex-automata"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
[[package]] [[package]]
name = "remove_background" name = "remove_background"
version = "0.1.0" version = "0.1.0"
@ -2714,6 +2758,8 @@ dependencies = [
"reqwest", "reqwest",
"sha2", "sha2",
"show-image", "show-image",
"tracing",
"tracing-subscriber",
] ]
[[package]] [[package]]
@ -3029,6 +3075,15 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]] [[package]]
name = "shlex" name = "shlex"
version = "1.3.0" version = "1.3.0"
@ -3303,6 +3358,15 @@ dependencies = [
"syn 2.0.117", "syn 2.0.117",
] ]
[[package]]
name = "thread_local"
version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185"
dependencies = [
"cfg-if",
]
[[package]] [[package]]
name = "tiff" name = "tiff"
version = "0.11.3" version = "0.11.3"
@ -3473,9 +3537,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
dependencies = [ dependencies = [
"pin-project-lite", "pin-project-lite",
"tracing-attributes",
"tracing-core", "tracing-core",
] ]
[[package]]
name = "tracing-attributes"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]] [[package]]
name = "tracing-core" name = "tracing-core"
version = "0.1.36" version = "0.1.36"
@ -3483,6 +3559,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a"
dependencies = [ dependencies = [
"once_cell", "once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex-automata",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
] ]
[[package]] [[package]]
@ -3598,6 +3704,12 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "valuable"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]] [[package]]
name = "vcpkg" name = "vcpkg"
version = "0.2.15" version = "0.2.15"

View file

@ -12,3 +12,5 @@
reqwest = { version = "0.13", features = ["blocking"] } reqwest = { version = "0.13", features = ["blocking"] }
sha2 = "0.11" sha2 = "0.11"
show-image = { version = "0.14", features = ["image"] } show-image = { version = "0.14", features = ["image"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }

View file

@ -13,6 +13,8 @@ use {
io::{Cursor, Read, Write, stdin, stdout}, io::{Cursor, Read, Write, stdin, stdout},
path::PathBuf, path::PathBuf,
}, },
tracing::{debug, info},
tracing_subscriber::{EnvFilter, fmt},
}; };
#[derive(Subcommand)] #[derive(Subcommand)]
@ -67,8 +69,20 @@ struct Args {
#[show_image::main] #[show_image::main]
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
fmt()
.with_env_filter(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
.with_writer(std::io::stderr)
.init();
let args = Args::parse(); let args = Args::parse();
info!(
model = ?args.model,
custom_model = args.model_path.as_deref().unwrap_or("<none>"),
offline = args.offline,
"Starting remove_background"
);
let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?; let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?;
match args.command { match args.command {
@ -78,8 +92,10 @@ fn main() -> Result<(), Box<dyn Error>> {
output_file, output_file,
} => { } => {
let img = if let Some(input_file) = input_file { let img = if let Some(input_file) = input_file {
debug!(path = %input_file.display(), "Reading input image from file");
ImageReader::open(input_file)?.decode()? ImageReader::open(input_file)?.decode()?
} else { } else {
debug!("Reading input image from stdin");
let mut bytes = Vec::new(); let mut bytes = Vec::new();
if stdin().lock().read_to_end(&mut bytes)? == 0 { if stdin().lock().read_to_end(&mut bytes)? == 0 {
return Err("No input data provided via stdin".into()); return Err("No input data provided via stdin".into());
@ -90,15 +106,9 @@ fn main() -> Result<(), Box<dyn Error>> {
}; };
let (img_width, img_height) = img.dimensions(); let (img_width, img_height) = img.dimensions();
println!("Loaded image: {}x{}\n", img_width, img_height); info!(width = img_width, height = img_height, "Loaded image");
println!("Running inference...");
let mask = session.predict(&img)?; let mask = session.predict(&img)?;
println!(
"Inference complete! Mask dimensions: {}x{}",
mask.dimensions().0,
mask.dimensions().1
);
let result_rgba = apply_mask(&img, &mask)?; let result_rgba = apply_mask(&img, &mask)?;
if debug { if debug {
@ -106,8 +116,10 @@ fn main() -> Result<(), Box<dyn Error>> {
} }
if let Some(output_file) = output_file { if let Some(output_file) = output_file {
result_rgba.save(output_file)?; result_rgba.save(&output_file)?;
info!(path = %output_file.display(), "Wrote output image to file");
} else { } else {
debug!("Writing output image to stdout (PNG)");
let mut buffer = Cursor::new(Vec::new()); let mut buffer = Cursor::new(Vec::new());
result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?; result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?;
let mut stdout = stdout().lock(); let mut stdout = stdout().lock();
@ -119,20 +131,49 @@ fn main() -> Result<(), Box<dyn Error>> {
output_directory, output_directory,
} => { } => {
fs::create_dir_all(&output_directory)?; fs::create_dir_all(&output_directory)?;
let files = fs::read_dir(input_directory)?; info!(
for file in files { input = %input_directory.display(),
let file = file?; output = %output_directory.display(),
let path = file.path(); "Starting batch processing"
let img = ImageReader::open(&path)?.decode()?; );
let mut processed: usize = 0;
let mut failed: usize = 0;
for entry in fs::read_dir(&input_directory)? {
let entry = entry?;
let path = entry.path();
let span = tracing::info_span!("batch_item", path = %path.display());
let _enter = span.enter();
let img = match ImageReader::open(&path) {
Ok(reader) => match reader.decode() {
Ok(img) => img,
Err(e) => {
tracing::warn!(error = %e, "Failed to decode image, skipping");
failed += 1;
continue;
}
},
Err(e) => {
tracing::warn!(error = %e, "Failed to open image, skipping");
failed += 1;
continue;
}
};
let mask = session.predict(&img)?; let mask = session.predict(&img)?;
let result_rgba = apply_mask(&img, &mask)?; let result_rgba = apply_mask(&img, &mask)?;
let output_path = output_directory.join(path.file_name().unwrap()); let output_path = output_directory.join(path.file_name().unwrap());
result_rgba.save(&output_path)?; result_rgba.save(&output_path)?;
println!("Processed image saved to: {}", output_path.display()); processed += 1;
info!(output = %output_path.display(), "Processed image saved");
} }
info!(processed, failed, "Batch processing complete");
} }
} }
info!("Shutting down");
Ok(()) Ok(())
} }
@ -140,7 +181,7 @@ fn debug_mode(
img: &image::DynamicImage, img: &image::DynamicImage,
result_rgba: &image::RgbaImage, result_rgba: &image::RgbaImage,
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
println!("Creating side-by-side comparison..."); info!("Creating side-by-side comparison");
let composite = create_side_by_side(img, result_rgba)?; let composite = create_side_by_side(img, result_rgba)?;
let composite_dynamic = image::DynamicImage::ImageRgba8(composite); let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
@ -156,21 +197,18 @@ fn debug_mode(
.map_err(|e| e.to_string())?, .map_err(|e| e.to_string())?,
)?; )?;
println!("\n=== Done! ==="); info!(
println!( width = comp_width,
"Displaying side-by-side comparison ({}x{}):", height = comp_height,
comp_width, comp_height "Displaying side-by-side comparison (left: original, right: background removed on checkered bg); press ESC to close"
); );
println!(" Left: Original image");
println!(" Right: Background removed (shown on checkered background)");
println!("\nPress ESC to close the window.");
for event in window.event_channel()? { for event in window.event_channel()? {
if let event::WindowEvent::KeyboardInput(event) = event if let event::WindowEvent::KeyboardInput(event) = event
&& event.input.key_code == Some(event::VirtualKeyCode::Escape) && event.input.key_code == Some(event::VirtualKeyCode::Escape)
&& event.input.state.is_pressed() && event.input.state.is_pressed()
{ {
println!("ESC pressed, closing..."); debug!("ESC pressed, closing window");
break; break;
} }
} }

View file

@ -8,9 +8,11 @@ use {
std::{ std::{
error::Error, error::Error,
fs, fs,
io::{Read, Write}, io::Read,
path::{Path, PathBuf}, path::{Path, PathBuf},
time::Instant,
}, },
tracing::{debug, info, warn},
}; };
/// CLI-facing model selector. Concrete session metadata (URL, checksum, /// CLI-facing model selector. Concrete session metadata (URL, checksum,
@ -34,7 +36,8 @@ fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
} }
fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> { fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
println!("Downloading model from {}...", url); info!(%url, dest = %dest.display(), "Downloading model");
let start = Instant::now();
let client = reqwest::blocking::Client::builder() let client = reqwest::blocking::Client::builder()
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36") .user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36")
@ -46,9 +49,10 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
return Err(format!("Failed to download model: HTTP {}", response.status()).into()); return Err(format!("Failed to download model: HTTP {}", response.status()).into());
} }
let mut file = fs::File::create(dest)?; let mut file = std::io::BufWriter::new(fs::File::create(dest)?);
let total_size = response.content_length().unwrap_or(0); let total_size = response.content_length().unwrap_or(0);
let mut downloaded = 0u64; let mut downloaded = 0u64;
let mut last_reported_pct: u32 = 0;
let mut buffer = vec![0; 8192]; let mut buffer = vec![0; 8192];
loop { loop {
@ -56,21 +60,24 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
if bytes_read == 0 { if bytes_read == 0 {
break; break;
} }
file.write_all(&buffer[..bytes_read])?; std::io::Write::write_all(&mut file, &buffer[..bytes_read])?;
downloaded += bytes_read as u64; downloaded += bytes_read as u64;
if total_size > 0 { if total_size > 0 {
let progress = (downloaded as f64 / total_size as f64 * 100.0) as u32; let progress = (downloaded as f64 / total_size as f64 * 100.0) as u32;
print!("\rDownloading... {}%", progress); // Only emit every 10% to avoid log spam while still giving feedback.
std::io::stdout().flush()?; if progress >= last_reported_pct + 10 {
last_reported_pct = progress - (progress % 10);
debug!(percent = last_reported_pct, downloaded, total = total_size, "Download progress");
}
} }
} }
if total_size > 0 { info!(
println!("\rDownload complete! "); bytes = downloaded,
} else { elapsed_ms = start.elapsed().as_millis() as u64,
println!("Download complete! ({} bytes)", downloaded); "Download complete"
} );
Ok(()) Ok(())
} }
@ -110,7 +117,7 @@ pub fn get_model_path(
if !model_path.exists() { if !model_path.exists() {
return Err(format!("Custom model path does not exist: {}", path).into()); return Err(format!("Custom model path does not exist: {}", path).into());
} }
println!("Using custom model: {}", path); info!(%path, "Using custom model");
return Ok(model_path); return Ok(model_path);
} }
@ -119,15 +126,14 @@ pub fn get_model_path(
let model_path = cache_dir.join(&model_filename); let model_path = cache_dir.join(&model_filename);
if model_path.exists() { if model_path.exists() {
println!("Using cached model: {}", model_path.display()); info!(path = %model_path.display(), "Using cached model");
if let Some(expected_hash) = sha256 { if let Some(expected_hash) = sha256 {
print!("Verifying model integrity... "); debug!("Verifying cached model integrity");
std::io::stdout().flush()?;
if verify_hash(&model_path, expected_hash)? { if verify_hash(&model_path, expected_hash)? {
println!("OK"); debug!("Cached model hash OK");
} else { } else {
println!("FAILED"); warn!(path = %model_path.display(), "Cached model hash verification failed");
return Err( return Err(
"Model hash verification failed. Try deleting the cached model.".into(), "Model hash verification failed. Try deleting the cached model.".into(),
); );
@ -145,14 +151,13 @@ pub fn get_model_path(
.into()); .into());
} }
println!("Model not found in cache, downloading..."); info!("Model not found in cache, downloading");
download_file(url, &model_path)?; download_file(url, &model_path)?;
if let Some(expected_hash) = sha256 { if let Some(expected_hash) = sha256 {
print!("Verifying downloaded model... "); debug!("Verifying downloaded model");
std::io::stdout().flush()?;
if verify_hash(&model_path, expected_hash)? { if verify_hash(&model_path, expected_hash)? {
println!("OK"); debug!("Downloaded model hash OK");
} else { } else {
fs::remove_file(&model_path)?; fs::remove_file(&model_path)?;
return Err("Downloaded model hash verification failed".into()); return Err("Downloaded model hash verification failed".into());
@ -164,7 +169,8 @@ pub fn get_model_path(
/// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU). /// Create an ONNX Runtime session from a model path with CUDA (falls back to CPU).
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> { pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
println!("Loading model into ONNX Runtime with CUDA backend..."); info!(path = %model_path.display(), "Loading model into ONNX Runtime with CUDA backend");
let start = Instant::now();
let session = Session::builder()? let session = Session::builder()?
.with_execution_providers([CUDAExecutionProvider::default().build()])? .with_execution_providers([CUDAExecutionProvider::default().build()])?
@ -172,7 +178,10 @@ pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
.with_intra_threads(4)? .with_intra_threads(4)?
.commit_from_file(model_path)?; .commit_from_file(model_path)?;
println!("Model loaded successfully with CUDA!"); info!(
elapsed_ms = start.elapsed().as_millis() as u64,
"Model loaded successfully"
);
Ok(session) Ok(session)
} }

View file

@ -1,6 +1,7 @@
use { use {
image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType}, image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType},
std::error::Error, std::error::Error,
tracing::debug,
}; };
/// Compose `original` with `mask` as the alpha channel and return an RGBA image. /// Compose `original` with `mask` as the alpha channel and return an RGBA image.
@ -8,17 +9,18 @@ use {
/// The mask is expected to already be grayscale. If its dimensions differ from /// The mask is expected to already be grayscale. If its dimensions differ from
/// the original, it is resized with LANCZOS3. /// the original, it is resized with LANCZOS3.
pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage, Box<dyn Error>> { pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage, Box<dyn Error>> {
println!("Applying mask to remove background...");
let (orig_width, orig_height) = original.dimensions(); let (orig_width, orig_height) = original.dimensions();
let (mask_width, mask_height) = mask.dimensions(); let (mask_width, mask_height) = mask.dimensions();
println!("Mask dimensions: {}x{}", mask_width, mask_height);
println!("Original dimensions: {}x{}", orig_width, orig_height);
let resized_mask: std::borrow::Cow<'_, GrayImage> = let resized_mask: std::borrow::Cow<'_, GrayImage> =
if mask_width != orig_width || mask_height != orig_height { if mask_width != orig_width || mask_height != orig_height {
println!("Resizing mask to match original image..."); debug!(
mask_width,
mask_height,
orig_width,
orig_height,
"Resizing mask to match original image"
);
std::borrow::Cow::Owned(image::imageops::resize( std::borrow::Cow::Owned(image::imageops::resize(
mask, mask,
orig_width, orig_width,
@ -45,7 +47,7 @@ pub fn apply_mask(original: &DynamicImage, mask: &GrayImage) -> Result<RgbaImage
} }
} }
println!("Background removal complete!"); debug!("Mask applied to image");
Ok(result) Ok(result)
} }

View file

@ -2,7 +2,8 @@ use {
image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType}, image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType},
ndarray::{Array4, IntoDimension}, ndarray::{Array4, IntoDimension},
ort::{session::Session as OrtSession, value::Tensor}, ort::{session::Session as OrtSession, value::Tensor},
std::error::Error, std::{error::Error, time::Instant},
tracing::{debug, info},
}; };
use crate::model::Model; use crate::model::Model;
@ -105,13 +106,22 @@ pub trait Session {
/// 5. scale to `u8`, resize to original image dimensions /// 5. scale to `u8`, resize to original image dimensions
fn predict(&mut self, img: &DynamicImage) -> Result<GrayImage, Box<dyn Error>> { fn predict(&mut self, img: &DynamicImage) -> Result<GrayImage, Box<dyn Error>> {
let (orig_w, orig_h) = img.dimensions(); let (orig_w, orig_h) = img.dimensions();
let preprocess_start = Instant::now();
let input = self.normalize(img)?; let input = self.normalize(img)?;
debug!(elapsed_ms = preprocess_start.elapsed().as_millis() as u64, "Preprocessing complete");
let apply_sigmoid = self.apply_sigmoid(); let apply_sigmoid = self.apply_sigmoid();
let input_name = self.input_name().to_string(); let input_name = self.input_name().to_string();
debug!(%input_name, "Running ONNX session");
let infer_start = Instant::now();
let outputs = self let outputs = self
.inner() .inner()
.run(ort::inputs![input_name => Tensor::from_array(input)?])?; .run(ort::inputs![input_name => Tensor::from_array(input)?])?;
info!(
elapsed_ms = infer_start.elapsed().as_millis() as u64,
"ONNX inference complete"
);
let output = &outputs[0]; let output = &outputs[0];
let (shape, data) = output.try_extract_tensor::<f32>()?; let (shape, data) = output.try_extract_tensor::<f32>()?;
@ -177,6 +187,11 @@ pub trait Session {
} }
let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3); let mask = image::imageops::resize(&mask, orig_w, orig_h, FilterType::Lanczos3);
debug!(
mask_width = orig_w,
mask_height = orig_h,
"Mask resized to original image dimensions"
);
Ok(mask) Ok(mask)
} }
} }
@ -191,7 +206,7 @@ pub fn init_session(
if !path.exists() { if !path.exists() {
return Err(format!("Custom model path does not exist: {}", custom_path).into()); return Err(format!("Custom model path does not exist: {}", custom_path).into());
} }
println!("Using custom model: {}", custom_path); info!(path = %custom_path, "Using custom model");
match model { match model {
Model::Bria => Box::new(BriaSession::from_model_path(&path)?), Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?), Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
@ -199,11 +214,11 @@ pub fn init_session(
} else { } else {
match model { match model {
Model::Bria => { Model::Bria => {
println!("Using model: bria-rmbg"); info!(model = "bria-rmbg", "Using model");
Box::new(BriaSession::new(offline)?) Box::new(BriaSession::new(offline)?)
} }
Model::BiRefNetLite => { Model::BiRefNetLite => {
println!("Using model: birefnet-general-lite"); info!(model = "birefnet-general-lite", "Using model");
Box::new(BiRefNetLiteSession::new(offline)?) Box::new(BiRefNetLiteSession::new(offline)?)
} }
} }