fmt
This commit is contained in:
parent
296a0172f1
commit
86e0dd734d
4 changed files with 85 additions and 80 deletions
20
src/main.rs
20
src/main.rs
|
|
@ -1,5 +1,6 @@
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use image::GenericImageView;
|
use image::GenericImageView;
|
||||||
|
use ndarray::IntoDimension;
|
||||||
use ort::value::Tensor;
|
use ort::value::Tensor;
|
||||||
use show_image::{AsImageView, create_window, event};
|
use show_image::{AsImageView, create_window, event};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
|
@ -73,7 +74,18 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
let mask_output = &outputs[0];
|
let mask_output = &outputs[0];
|
||||||
let (mask_shape, mask_array) = mask_output.try_extract_tensor::<f32>()?; // Returns (shape, &[f32])
|
let (mask_shape, mask_array) = mask_output.try_extract_tensor::<f32>()?; // Returns (shape, &[f32])
|
||||||
|
|
||||||
/*let mask_tensor = mask_array.into_dimensionality::<ndarray::Ix4>()?.to_owned();
|
// Convert the slice to Array4
|
||||||
|
let mask_tensor = ndarray::ArrayView::from_shape(
|
||||||
|
(
|
||||||
|
mask_shape[0] as usize,
|
||||||
|
mask_shape[1] as usize,
|
||||||
|
mask_shape[2] as usize,
|
||||||
|
mask_shape[3] as usize,
|
||||||
|
)
|
||||||
|
.into_dimension(),
|
||||||
|
mask_array,
|
||||||
|
)?
|
||||||
|
.to_owned();
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"Inference complete! Output shape: {:?}",
|
"Inference complete! Output shape: {:?}",
|
||||||
|
|
@ -108,11 +120,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
);
|
);
|
||||||
println!(" Left: Original image");
|
println!(" Left: Original image");
|
||||||
println!(" Right: Background removed (shown on checkered background)");
|
println!(" Right: Background removed (shown on checkered background)");
|
||||||
println!("\nPress ESC to close the window."); */
|
println!("\nPress ESC to close the window.");
|
||||||
let window = create_window(
|
|
||||||
"Background Removal - Original (Left) vs Result (Right) - Press ESC to close",
|
|
||||||
Default::default(),
|
|
||||||
)?;
|
|
||||||
// Event loop - wait for ESC key to close
|
// Event loop - wait for ESC key to close
|
||||||
for event in window.event_channel()? {
|
for event in window.event_channel()? {
|
||||||
if let event::WindowEvent::KeyboardInput(event) = event {
|
if let event::WindowEvent::KeyboardInput(event) = event {
|
||||||
|
|
|
||||||
25
src/model.rs
25
src/model.rs
|
|
@ -1,5 +1,5 @@
|
||||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||||
use sha2::{Sha256, Digest};
|
use sha2::{Digest, Sha256};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
@ -17,7 +17,7 @@ impl ModelInfo {
|
||||||
/// BiRefNet General Lite model
|
/// BiRefNet General Lite model
|
||||||
pub const BIREFNET_LITE: ModelInfo = ModelInfo {
|
pub const BIREFNET_LITE: ModelInfo = ModelInfo {
|
||||||
name: "birefnet-general-lite",
|
name: "birefnet-general-lite",
|
||||||
url: "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet-general-lite.onnx",
|
url: "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png",
|
||||||
sha256: None, // We'll skip verification for now
|
sha256: None, // We'll skip verification for now
|
||||||
input_size: (1024, 1024),
|
input_size: (1024, 1024),
|
||||||
};
|
};
|
||||||
|
|
@ -25,9 +25,11 @@ impl ModelInfo {
|
||||||
|
|
||||||
/// Get the cache directory for models
|
/// Get the cache directory for models
|
||||||
fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
|
fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
|
||||||
let home = std::env::var("HOME")
|
let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"))?;
|
||||||
.or_else(|_| std::env::var("USERPROFILE"))?;
|
let cache_dir = Path::new(&home)
|
||||||
let cache_dir = Path::new(&home).join(".cache").join("remove_background").join("models");
|
.join(".cache")
|
||||||
|
.join("remove_background")
|
||||||
|
.join("models");
|
||||||
fs::create_dir_all(&cache_dir)?;
|
fs::create_dir_all(&cache_dir)?;
|
||||||
Ok(cache_dir)
|
Ok(cache_dir)
|
||||||
}
|
}
|
||||||
|
|
@ -86,7 +88,11 @@ fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Er
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or download model, return path to model file
|
/// Get or download model, return path to model file
|
||||||
pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline: bool) -> Result<PathBuf, Box<dyn Error>> {
|
pub fn get_model_path(
|
||||||
|
model_info: &ModelInfo,
|
||||||
|
custom_path: Option<&str>,
|
||||||
|
offline: bool,
|
||||||
|
) -> Result<PathBuf, Box<dyn Error>> {
|
||||||
// If custom path provided, use it
|
// If custom path provided, use it
|
||||||
if let Some(path) = custom_path {
|
if let Some(path) = custom_path {
|
||||||
let model_path = PathBuf::from(path);
|
let model_path = PathBuf::from(path);
|
||||||
|
|
@ -113,7 +119,9 @@ pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline
|
||||||
println!("OK");
|
println!("OK");
|
||||||
} else {
|
} else {
|
||||||
println!("FAILED");
|
println!("FAILED");
|
||||||
return Err("Model hash verification failed. Try deleting the cached model.".into());
|
return Err(
|
||||||
|
"Model hash verification failed. Try deleting the cached model.".into(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -125,7 +133,8 @@ pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"Model not found in cache and offline mode is enabled. Cache path: {}",
|
"Model not found in cache and offline mode is enabled. Cache path: {}",
|
||||||
model_path.display()
|
model_path.display()
|
||||||
).into());
|
)
|
||||||
|
.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Model not found in cache, downloading...");
|
println!("Model not found in cache, downloading...");
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use image::{DynamicImage, RgbaImage, Rgba, GenericImageView, imageops::FilterType};
|
use image::{DynamicImage, GenericImageView, Rgba, RgbaImage, imageops::FilterType};
|
||||||
use ndarray::Array4;
|
use ndarray::Array4;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
|
||||||
|
|
@ -19,10 +19,7 @@ pub fn apply_mask(
|
||||||
// Extract mask dimensions
|
// Extract mask dimensions
|
||||||
let mask_shape = mask_tensor.shape();
|
let mask_shape = mask_tensor.shape();
|
||||||
if mask_shape[0] != 1 || mask_shape[1] != 1 {
|
if mask_shape[0] != 1 || mask_shape[1] != 1 {
|
||||||
return Err(format!(
|
return Err(format!("Expected mask shape [1, 1, H, W], got {:?}", mask_shape).into());
|
||||||
"Expected mask shape [1, 1, H, W], got {:?}",
|
|
||||||
mask_shape
|
|
||||||
).into());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mask_height = mask_shape[2] as u32;
|
let mask_height = mask_shape[2] as u32;
|
||||||
|
|
@ -45,12 +42,7 @@ pub fn apply_mask(
|
||||||
// Resize mask to match original image dimensions if needed
|
// Resize mask to match original image dimensions if needed
|
||||||
let resized_mask = if mask_width != orig_width || mask_height != orig_height {
|
let resized_mask = if mask_width != orig_width || mask_height != orig_height {
|
||||||
println!("Resizing mask to match original image...");
|
println!("Resizing mask to match original image...");
|
||||||
image::imageops::resize(
|
image::imageops::resize(&mask_image, orig_width, orig_height, FilterType::Lanczos3)
|
||||||
&mask_image,
|
|
||||||
orig_width,
|
|
||||||
orig_height,
|
|
||||||
FilterType::Lanczos3,
|
|
||||||
)
|
|
||||||
} else {
|
} else {
|
||||||
mask_image
|
mask_image
|
||||||
};
|
};
|
||||||
|
|
@ -115,11 +107,7 @@ pub fn create_side_by_side(
|
||||||
let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
|
let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
|
||||||
let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
|
let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
|
||||||
|
|
||||||
composite.put_pixel(
|
composite.put_pixel(x + width, y, Rgba([final_r, final_g, final_b, 255]));
|
||||||
x + width,
|
|
||||||
y,
|
|
||||||
Rgba([final_r, final_g, final_b, 255]),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue