single and batch
This commit is contained in:
parent
395990e47d
commit
8d2fad9106
6 changed files with 786 additions and 483 deletions
1007
Cargo.lock
generated
1007
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -4,10 +4,11 @@
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap = { version = "4.5", features = ["derive"] }
|
clap = { version = "4", features = ["derive"] }
|
||||||
|
hex = "0.4"
|
||||||
image = "0.25"
|
image = "0.25"
|
||||||
ndarray = "0.17"
|
ndarray = "0.17"
|
||||||
ort = "=2.0.0-rc.11"
|
ort = "=2.0.0-rc.12"
|
||||||
reqwest = { version = "0.13", features = ["blocking"] }
|
reqwest = { version = "0.13", features = ["blocking"] }
|
||||||
sha2 = "0.10"
|
sha2 = "0.11"
|
||||||
show-image = { version = "0.14", features = ["image"] }
|
show-image = { version = "0.14", features = ["image"] }
|
||||||
|
|
|
||||||
171
src/main.rs
171
src/main.rs
|
|
@ -1,21 +1,50 @@
|
||||||
use clap::Parser;
|
use {
|
||||||
use image::GenericImageView;
|
clap::{Parser, Subcommand},
|
||||||
use show_image::{AsImageView, create_window, event};
|
image::{GenericImageView, ImageReader},
|
||||||
use std::error::Error;
|
remove_background::{
|
||||||
|
model::Model,
|
||||||
use remove_background::{
|
postprocessing::{apply_mask, create_side_by_side},
|
||||||
model::Model,
|
sessions::init_session,
|
||||||
postprocessing::{apply_mask, create_side_by_side},
|
},
|
||||||
sessions::{BiRefNetLiteSession, BriaSession, Session},
|
show_image::{AsImageView, create_window, event},
|
||||||
|
std::{
|
||||||
|
error::Error,
|
||||||
|
fs,
|
||||||
|
io::{Cursor, Read, Write, stdin, stdout},
|
||||||
|
path::PathBuf,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Subcommand)]
|
||||||
|
enum Command {
|
||||||
|
Single {
|
||||||
|
#[arg(long, help = "Run in debug mode, show images side by side")]
|
||||||
|
debug: bool,
|
||||||
|
|
||||||
|
#[arg(
|
||||||
|
value_name = "INPUT_FILE",
|
||||||
|
help = "File to process, if not provided, read from stdin"
|
||||||
|
)]
|
||||||
|
input_file: Option<PathBuf>,
|
||||||
|
|
||||||
|
#[arg(
|
||||||
|
value_name = "OUTPUT_FILE",
|
||||||
|
help = "File to write result to, if not provided, write to stdout"
|
||||||
|
)]
|
||||||
|
output_file: Option<PathBuf>,
|
||||||
|
},
|
||||||
|
Batch {
|
||||||
|
#[arg(value_name = "INPUT_DIRECTORY", help = "Directory to read images from")]
|
||||||
|
input_directory: PathBuf,
|
||||||
|
#[arg(value_name = "OUTPUT_DIRECTORY", help = "Directory to write images to")]
|
||||||
|
output_directory: PathBuf,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "remove_background")]
|
#[command(name = "remove_background")]
|
||||||
#[command(about = "Remove background from images using ONNX models", long_about = None)]
|
#[command(about = "Remove background from images using ONNX models", long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
#[arg(help = "URL of the image to download and process")]
|
|
||||||
url: String,
|
|
||||||
|
|
||||||
#[arg(
|
#[arg(
|
||||||
short,
|
short,
|
||||||
long,
|
long,
|
||||||
|
|
@ -31,66 +60,88 @@ struct Args {
|
||||||
|
|
||||||
#[arg(long, help = "Skip model download, fail if not cached")]
|
#[arg(long, help = "Skip model download, fail if not cached")]
|
||||||
offline: bool,
|
offline: bool,
|
||||||
|
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Command,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[show_image::main]
|
#[show_image::main]
|
||||||
fn main() -> Result<(), Box<dyn Error>> {
|
fn main() -> Result<(), Box<dyn Error>> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
println!("=== Background Removal Tool ===\n");
|
let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?;
|
||||||
println!("Downloading image from: {}", args.url);
|
|
||||||
|
|
||||||
let client = reqwest::blocking::Client::builder()
|
match args.command {
|
||||||
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
Command::Single {
|
||||||
.build()?;
|
debug,
|
||||||
let response = client.get(&args.url).send()?;
|
input_file,
|
||||||
|
output_file,
|
||||||
|
} => {
|
||||||
|
let img = if let Some(input_file) = input_file {
|
||||||
|
ImageReader::open(input_file)?.decode()?
|
||||||
|
} else {
|
||||||
|
let mut bytes = Vec::new();
|
||||||
|
if stdin().lock().read_to_end(&mut bytes)? == 0 {
|
||||||
|
return Err("No input data provided via stdin".into());
|
||||||
|
}
|
||||||
|
ImageReader::new(Cursor::new(bytes))
|
||||||
|
.with_guessed_format()?
|
||||||
|
.decode()?
|
||||||
|
};
|
||||||
|
|
||||||
if !response.status().is_success() {
|
let (img_width, img_height) = img.dimensions();
|
||||||
return Err(format!("Failed to download image: HTTP {}", response.status()).into());
|
println!("Loaded image: {}x{}\n", img_width, img_height);
|
||||||
|
|
||||||
|
println!("Running inference...");
|
||||||
|
let mask = session.predict(&img)?;
|
||||||
|
println!(
|
||||||
|
"Inference complete! Mask dimensions: {}x{}",
|
||||||
|
mask.dimensions().0,
|
||||||
|
mask.dimensions().1
|
||||||
|
);
|
||||||
|
let result_rgba = apply_mask(&img, &mask)?;
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
debug_mode(&img, &result_rgba)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(output_file) = output_file {
|
||||||
|
result_rgba.save(output_file)?;
|
||||||
|
} else {
|
||||||
|
let mut buffer = Cursor::new(Vec::new());
|
||||||
|
result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?;
|
||||||
|
let mut stdout = stdout().lock();
|
||||||
|
stdout.write_all(&buffer.into_inner())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Command::Batch {
|
||||||
|
input_directory,
|
||||||
|
output_directory,
|
||||||
|
} => {
|
||||||
|
fs::create_dir_all(&output_directory)?;
|
||||||
|
let files = fs::read_dir(input_directory)?;
|
||||||
|
for file in files {
|
||||||
|
let file = file?;
|
||||||
|
let path = file.path();
|
||||||
|
let img = ImageReader::open(&path)?.decode()?;
|
||||||
|
let mask = session.predict(&img)?;
|
||||||
|
let result_rgba = apply_mask(&img, &mask)?;
|
||||||
|
let output_path = output_directory.join(path.file_name().unwrap());
|
||||||
|
result_rgba.save(&output_path)?;
|
||||||
|
println!("Processed image saved to: {}", output_path.display());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let bytes = response.bytes()?;
|
Ok(())
|
||||||
println!("Downloaded {} bytes", bytes.len());
|
}
|
||||||
|
|
||||||
let img = image::load_from_memory(&bytes)?;
|
|
||||||
let (img_width, img_height) = img.dimensions();
|
|
||||||
println!("Loaded image: {}x{}\n", img_width, img_height);
|
|
||||||
|
|
||||||
let mut session: Box<dyn Session> = if let Some(custom_path) = args.model_path.as_deref() {
|
|
||||||
let path = std::path::PathBuf::from(custom_path);
|
|
||||||
if !path.exists() {
|
|
||||||
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
|
||||||
}
|
|
||||||
println!("Using custom model: {}", custom_path);
|
|
||||||
match args.model {
|
|
||||||
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
|
|
||||||
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
match args.model {
|
|
||||||
Model::Bria => {
|
|
||||||
println!("Using model: bria-rmbg");
|
|
||||||
Box::new(BriaSession::new(args.offline)?)
|
|
||||||
}
|
|
||||||
Model::BiRefNetLite => {
|
|
||||||
println!("Using model: birefnet-general-lite");
|
|
||||||
Box::new(BiRefNetLiteSession::new(args.offline)?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("Running inference...");
|
|
||||||
let mask = session.predict(&img)?;
|
|
||||||
println!(
|
|
||||||
"Inference complete! Mask dimensions: {}x{}",
|
|
||||||
mask.dimensions().0,
|
|
||||||
mask.dimensions().1
|
|
||||||
);
|
|
||||||
|
|
||||||
let result_rgba = apply_mask(&img, &mask)?;
|
|
||||||
|
|
||||||
|
fn debug_mode(
|
||||||
|
img: &image::DynamicImage,
|
||||||
|
result_rgba: &image::RgbaImage,
|
||||||
|
) -> Result<(), Box<dyn Error>> {
|
||||||
println!("Creating side-by-side comparison...");
|
println!("Creating side-by-side comparison...");
|
||||||
let composite = create_side_by_side(&img, &result_rgba)?;
|
let composite = create_side_by_side(img, result_rgba)?;
|
||||||
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
|
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
|
||||||
|
|
||||||
let (comp_width, comp_height) = composite_dynamic.dimensions();
|
let (comp_width, comp_height) = composite_dynamic.dimensions();
|
||||||
|
|
|
||||||
37
src/model.rs
37
src/model.rs
|
|
@ -1,11 +1,17 @@
|
||||||
use clap::ValueEnum;
|
use {
|
||||||
use ort::execution_providers::CUDAExecutionProvider;
|
clap::ValueEnum,
|
||||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
ort::{
|
||||||
use sha2::{Digest, Sha256};
|
execution_providers::CUDAExecutionProvider,
|
||||||
use std::error::Error;
|
session::{Session, builder::GraphOptimizationLevel},
|
||||||
use std::fs;
|
},
|
||||||
use std::io::Write;
|
sha2::{Digest, Sha256},
|
||||||
use std::path::{Path, PathBuf};
|
std::{
|
||||||
|
error::Error,
|
||||||
|
fs,
|
||||||
|
io::{Read, Write},
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
|
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
|
||||||
/// preprocessing params) lives on the `Session` trait impls in `sessions/`.
|
/// preprocessing params) lives on the `Session` trait impls in `sessions/`.
|
||||||
|
|
@ -70,12 +76,19 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Error>> {
|
fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Error>> {
|
||||||
|
const BUF_SIZE: usize = 8192;
|
||||||
|
|
||||||
let mut file = fs::File::open(file_path)?;
|
let mut file = fs::File::open(file_path)?;
|
||||||
|
let mut buf = [0; BUF_SIZE];
|
||||||
let mut hasher = Sha256::new();
|
let mut hasher = Sha256::new();
|
||||||
std::io::copy(&mut file, &mut hasher)?;
|
loop {
|
||||||
let hash = hasher.finalize();
|
let bytes_read = file.read(&mut buf)?;
|
||||||
let hash_str = format!("{:x}", hash);
|
if bytes_read == 0 {
|
||||||
Ok(hash_str == expected_hash)
|
break;
|
||||||
|
}
|
||||||
|
hasher.update(&buf[..bytes_read]);
|
||||||
|
}
|
||||||
|
Ok(hex::encode(hasher.finalize()) == expected_hash)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve a model file to a local path, downloading + verifying if needed.
|
/// Resolve a model file to a local path, downloading + verifying if needed.
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
use image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType};
|
use {
|
||||||
use std::error::Error;
|
image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType},
|
||||||
|
std::error::Error,
|
||||||
|
};
|
||||||
|
|
||||||
/// Compose `original` with `mask` as the alpha channel and return an RGBA image.
|
/// Compose `original` with `mask` as the alpha channel and return an RGBA image.
|
||||||
///
|
///
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,11 @@
|
||||||
use image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType};
|
use {
|
||||||
use ndarray::{Array4, IntoDimension};
|
image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType},
|
||||||
use ort::{session::Session as OrtSession, value::Tensor};
|
ndarray::{Array4, IntoDimension},
|
||||||
use std::error::Error;
|
ort::{session::Session as OrtSession, value::Tensor},
|
||||||
|
std::error::Error,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::model::Model;
|
||||||
|
|
||||||
mod birefnet_lite;
|
mod birefnet_lite;
|
||||||
mod bria;
|
mod bria;
|
||||||
|
|
@ -176,3 +180,32 @@ pub trait Session {
|
||||||
Ok(mask)
|
Ok(mask)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn init_session(
|
||||||
|
custom_model_path: Option<&str>,
|
||||||
|
model: Model,
|
||||||
|
offline: bool,
|
||||||
|
) -> Result<Box<dyn Session>, Box<dyn Error>> {
|
||||||
|
Ok(if let Some(custom_path) = custom_model_path {
|
||||||
|
let path = std::path::PathBuf::from(custom_path);
|
||||||
|
if !path.exists() {
|
||||||
|
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
||||||
|
}
|
||||||
|
println!("Using custom model: {}", custom_path);
|
||||||
|
match model {
|
||||||
|
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
|
||||||
|
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match model {
|
||||||
|
Model::Bria => {
|
||||||
|
println!("Using model: bria-rmbg");
|
||||||
|
Box::new(BriaSession::new(offline)?)
|
||||||
|
}
|
||||||
|
Model::BiRefNetLite => {
|
||||||
|
println!("Using model: birefnet-general-lite");
|
||||||
|
Box::new(BiRefNetLiteSession::new(offline)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue