single and batch
This commit is contained in:
parent
395990e47d
commit
8d2fad9106
6 changed files with 786 additions and 483 deletions
1007
Cargo.lock
generated
1007
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -4,10 +4,11 @@
|
|||
version = "0.1.0"
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
hex = "0.4"
|
||||
image = "0.25"
|
||||
ndarray = "0.17"
|
||||
ort = "=2.0.0-rc.11"
|
||||
ort = "=2.0.0-rc.12"
|
||||
reqwest = { version = "0.13", features = ["blocking"] }
|
||||
sha2 = "0.10"
|
||||
sha2 = "0.11"
|
||||
show-image = { version = "0.14", features = ["image"] }
|
||||
|
|
|
|||
147
src/main.rs
147
src/main.rs
|
|
@ -1,21 +1,50 @@
|
|||
use clap::Parser;
|
||||
use image::GenericImageView;
|
||||
use show_image::{AsImageView, create_window, event};
|
||||
use std::error::Error;
|
||||
|
||||
use remove_background::{
|
||||
use {
|
||||
clap::{Parser, Subcommand},
|
||||
image::{GenericImageView, ImageReader},
|
||||
remove_background::{
|
||||
model::Model,
|
||||
postprocessing::{apply_mask, create_side_by_side},
|
||||
sessions::{BiRefNetLiteSession, BriaSession, Session},
|
||||
sessions::init_session,
|
||||
},
|
||||
show_image::{AsImageView, create_window, event},
|
||||
std::{
|
||||
error::Error,
|
||||
fs,
|
||||
io::{Cursor, Read, Write, stdin, stdout},
|
||||
path::PathBuf,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Command {
|
||||
Single {
|
||||
#[arg(long, help = "Run in debug mode, show images side by side")]
|
||||
debug: bool,
|
||||
|
||||
#[arg(
|
||||
value_name = "INPUT_FILE",
|
||||
help = "File to process, if not provided, read from stdin"
|
||||
)]
|
||||
input_file: Option<PathBuf>,
|
||||
|
||||
#[arg(
|
||||
value_name = "OUTPUT_FILE",
|
||||
help = "File to write result to, if not provided, write to stdout"
|
||||
)]
|
||||
output_file: Option<PathBuf>,
|
||||
},
|
||||
Batch {
|
||||
#[arg(value_name = "INPUT_DIRECTORY", help = "Directory to read images from")]
|
||||
input_directory: PathBuf,
|
||||
#[arg(value_name = "OUTPUT_DIRECTORY", help = "Directory to write images to")]
|
||||
output_directory: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "remove_background")]
|
||||
#[command(about = "Remove background from images using ONNX models", long_about = None)]
|
||||
struct Args {
|
||||
#[arg(help = "URL of the image to download and process")]
|
||||
url: String,
|
||||
|
||||
#[arg(
|
||||
short,
|
||||
long,
|
||||
|
|
@ -31,54 +60,38 @@ struct Args {
|
|||
|
||||
#[arg(long, help = "Skip model download, fail if not cached")]
|
||||
offline: bool,
|
||||
|
||||
#[command(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
|
||||
#[show_image::main]
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
let args = Args::parse();
|
||||
|
||||
println!("=== Background Removal Tool ===\n");
|
||||
println!("Downloading image from: {}", args.url);
|
||||
let mut session = init_session(args.model_path.as_deref(), args.model, args.offline)?;
|
||||
|
||||
let client = reqwest::blocking::Client::builder()
|
||||
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
||||
.build()?;
|
||||
let response = client.get(&args.url).send()?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("Failed to download image: HTTP {}", response.status()).into());
|
||||
match args.command {
|
||||
Command::Single {
|
||||
debug,
|
||||
input_file,
|
||||
output_file,
|
||||
} => {
|
||||
let img = if let Some(input_file) = input_file {
|
||||
ImageReader::open(input_file)?.decode()?
|
||||
} else {
|
||||
let mut bytes = Vec::new();
|
||||
if stdin().lock().read_to_end(&mut bytes)? == 0 {
|
||||
return Err("No input data provided via stdin".into());
|
||||
}
|
||||
ImageReader::new(Cursor::new(bytes))
|
||||
.with_guessed_format()?
|
||||
.decode()?
|
||||
};
|
||||
|
||||
let bytes = response.bytes()?;
|
||||
println!("Downloaded {} bytes", bytes.len());
|
||||
|
||||
let img = image::load_from_memory(&bytes)?;
|
||||
let (img_width, img_height) = img.dimensions();
|
||||
println!("Loaded image: {}x{}\n", img_width, img_height);
|
||||
|
||||
let mut session: Box<dyn Session> = if let Some(custom_path) = args.model_path.as_deref() {
|
||||
let path = std::path::PathBuf::from(custom_path);
|
||||
if !path.exists() {
|
||||
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
||||
}
|
||||
println!("Using custom model: {}", custom_path);
|
||||
match args.model {
|
||||
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
|
||||
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
|
||||
}
|
||||
} else {
|
||||
match args.model {
|
||||
Model::Bria => {
|
||||
println!("Using model: bria-rmbg");
|
||||
Box::new(BriaSession::new(args.offline)?)
|
||||
}
|
||||
Model::BiRefNetLite => {
|
||||
println!("Using model: birefnet-general-lite");
|
||||
Box::new(BiRefNetLiteSession::new(args.offline)?)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
println!("Running inference...");
|
||||
let mask = session.predict(&img)?;
|
||||
println!(
|
||||
|
|
@ -86,11 +99,49 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
mask.dimensions().0,
|
||||
mask.dimensions().1
|
||||
);
|
||||
|
||||
let result_rgba = apply_mask(&img, &mask)?;
|
||||
|
||||
if debug {
|
||||
debug_mode(&img, &result_rgba)?;
|
||||
}
|
||||
|
||||
if let Some(output_file) = output_file {
|
||||
result_rgba.save(output_file)?;
|
||||
} else {
|
||||
let mut buffer = Cursor::new(Vec::new());
|
||||
result_rgba.write_to(&mut buffer, image::ImageFormat::Png)?;
|
||||
let mut stdout = stdout().lock();
|
||||
stdout.write_all(&buffer.into_inner())?;
|
||||
}
|
||||
}
|
||||
Command::Batch {
|
||||
input_directory,
|
||||
output_directory,
|
||||
} => {
|
||||
fs::create_dir_all(&output_directory)?;
|
||||
let files = fs::read_dir(input_directory)?;
|
||||
for file in files {
|
||||
let file = file?;
|
||||
let path = file.path();
|
||||
let img = ImageReader::open(&path)?.decode()?;
|
||||
let mask = session.predict(&img)?;
|
||||
let result_rgba = apply_mask(&img, &mask)?;
|
||||
let output_path = output_directory.join(path.file_name().unwrap());
|
||||
result_rgba.save(&output_path)?;
|
||||
println!("Processed image saved to: {}", output_path.display());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn debug_mode(
|
||||
img: &image::DynamicImage,
|
||||
result_rgba: &image::RgbaImage,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
println!("Creating side-by-side comparison...");
|
||||
let composite = create_side_by_side(&img, &result_rgba)?;
|
||||
let composite = create_side_by_side(img, result_rgba)?;
|
||||
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
|
||||
|
||||
let (comp_width, comp_height) = composite_dynamic.dimensions();
|
||||
|
|
|
|||
37
src/model.rs
37
src/model.rs
|
|
@ -1,11 +1,17 @@
|
|||
use clap::ValueEnum;
|
||||
use ort::execution_providers::CUDAExecutionProvider;
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::error::Error;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use {
|
||||
clap::ValueEnum,
|
||||
ort::{
|
||||
execution_providers::CUDAExecutionProvider,
|
||||
session::{Session, builder::GraphOptimizationLevel},
|
||||
},
|
||||
sha2::{Digest, Sha256},
|
||||
std::{
|
||||
error::Error,
|
||||
fs,
|
||||
io::{Read, Write},
|
||||
path::{Path, PathBuf},
|
||||
},
|
||||
};
|
||||
|
||||
/// CLI-facing model selector. Concrete session metadata (URL, checksum,
|
||||
/// preprocessing params) lives on the `Session` trait impls in `sessions/`.
|
||||
|
|
@ -70,12 +76,19 @@ fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
|
|||
}
|
||||
|
||||
fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Error>> {
|
||||
const BUF_SIZE: usize = 8192;
|
||||
|
||||
let mut file = fs::File::open(file_path)?;
|
||||
let mut buf = [0; BUF_SIZE];
|
||||
let mut hasher = Sha256::new();
|
||||
std::io::copy(&mut file, &mut hasher)?;
|
||||
let hash = hasher.finalize();
|
||||
let hash_str = format!("{:x}", hash);
|
||||
Ok(hash_str == expected_hash)
|
||||
loop {
|
||||
let bytes_read = file.read(&mut buf)?;
|
||||
if bytes_read == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buf[..bytes_read]);
|
||||
}
|
||||
Ok(hex::encode(hasher.finalize()) == expected_hash)
|
||||
}
|
||||
|
||||
/// Resolve a model file to a local path, downloading + verifying if needed.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
use image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType};
|
||||
use std::error::Error;
|
||||
use {
|
||||
image::{DynamicImage, GenericImageView, GrayImage, Rgba, RgbaImage, imageops::FilterType},
|
||||
std::error::Error,
|
||||
};
|
||||
|
||||
/// Compose `original` with `mask` as the alpha channel and return an RGBA image.
|
||||
///
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
use image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType};
|
||||
use ndarray::{Array4, IntoDimension};
|
||||
use ort::{session::Session as OrtSession, value::Tensor};
|
||||
use std::error::Error;
|
||||
use {
|
||||
image::{DynamicImage, GenericImageView, GrayImage, Luma, imageops::FilterType},
|
||||
ndarray::{Array4, IntoDimension},
|
||||
ort::{session::Session as OrtSession, value::Tensor},
|
||||
std::error::Error,
|
||||
};
|
||||
|
||||
use crate::model::Model;
|
||||
|
||||
mod birefnet_lite;
|
||||
mod bria;
|
||||
|
|
@ -176,3 +180,32 @@ pub trait Session {
|
|||
Ok(mask)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_session(
|
||||
custom_model_path: Option<&str>,
|
||||
model: Model,
|
||||
offline: bool,
|
||||
) -> Result<Box<dyn Session>, Box<dyn Error>> {
|
||||
Ok(if let Some(custom_path) = custom_model_path {
|
||||
let path = std::path::PathBuf::from(custom_path);
|
||||
if !path.exists() {
|
||||
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
||||
}
|
||||
println!("Using custom model: {}", custom_path);
|
||||
match model {
|
||||
Model::Bria => Box::new(BriaSession::from_model_path(&path)?),
|
||||
Model::BiRefNetLite => Box::new(BiRefNetLiteSession::from_model_path(&path)?),
|
||||
}
|
||||
} else {
|
||||
match model {
|
||||
Model::Bria => {
|
||||
println!("Using model: bria-rmbg");
|
||||
Box::new(BriaSession::new(offline)?)
|
||||
}
|
||||
Model::BiRefNetLite => {
|
||||
println!("Using model: birefnet-general-lite");
|
||||
Box::new(BiRefNetLiteSession::new(offline)?)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue