compiles so far

This commit is contained in:
Matthew Deville 2026-01-22 00:25:11 +01:00
parent 04054560d1
commit 296a0172f1
8 changed files with 930 additions and 12 deletions

371
Cargo.lock generated
View file

@ -303,6 +303,12 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "base64ct"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06"
[[package]]
name = "bit-set"
version = "0.5.3"
@ -351,6 +357,15 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a"
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]]
name = "block-sys"
version = "0.1.0-beta.1"
@ -388,6 +403,12 @@ version = "1.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "byteorder-lite"
version = "0.1.0"
@ -596,6 +617,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "cpufeatures"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
dependencies = [
"libc",
]
[[package]]
name = "crc32fast"
version = "1.5.0"
@ -636,6 +666,16 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5"
[[package]]
name = "crypto-common"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a"
dependencies = [
"generic-array",
"typenum",
]
[[package]]
name = "d3d12"
version = "0.7.0"
@ -647,6 +687,26 @@ dependencies = [
"winapi",
]
[[package]]
name = "der"
version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb"
dependencies = [
"pem-rfc7468",
"zeroize",
]
[[package]]
name = "digest"
version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
]
[[package]]
name = "dispatch"
version = "0.2.0"
@ -726,6 +786,16 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "errno"
version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.61.2",
]
[[package]]
name = "exr"
version = "1.74.0"
@ -741,6 +811,12 @@ dependencies = [
"zune-inflate",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fax"
version = "0.2.6"
@ -932,6 +1008,16 @@ dependencies = [
"slab",
]
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "getrandom"
version = "0.2.17"
@ -1124,6 +1210,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df"
[[package]]
name = "hmac-sha256"
version = "1.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad6880c8d4a9ebf39c6e8b77007ce223f646a4d21ce29d99f70cb16420545425"
[[package]]
name = "http"
version = "1.4.0"
@ -1561,6 +1653,12 @@ dependencies = [
"redox_syscall 0.7.0",
]
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
[[package]]
name = "litemap"
version = "0.8.1"
@ -1597,6 +1695,12 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]]
name = "lzma-rust2"
version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1670343e58806300d87950e3401e820b519b9384281bbabfb15e3636689ffd69"
[[package]]
name = "malloc_buf"
version = "0.0.6"
@ -1606,6 +1710,16 @@ dependencies = [
"libc",
]
[[package]]
name = "matrixmultiply"
version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]]
name = "maybe-rayon"
version = "0.1.1"
@ -1725,6 +1839,38 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "native-tls"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe 0.1.6",
"openssl-sys",
"schannel",
"security-framework 2.11.1",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "ndarray"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]]
name = "ndk"
version = "0.7.0"
@ -1810,6 +1956,15 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-derive"
version = "0.4.2"
@ -1958,12 +2113,56 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
[[package]]
name = "openssl"
version = "0.10.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
dependencies = [
"bitflags 2.10.0",
"cfg-if",
"foreign-types 0.3.2",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-probe"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
[[package]]
name = "openssl-sys"
version = "0.9.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "orbclient"
version = "0.3.50"
@ -1974,6 +2173,30 @@ dependencies = [
"libredox",
]
[[package]]
name = "ort"
version = "2.0.0-rc.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5df903c0d2c07b56950f1058104ab0c8557159f2741782223704de9be73c3c"
dependencies = [
"ndarray",
"ort-sys",
"smallvec",
"tracing",
"ureq",
]
[[package]]
name = "ort-sys"
version = "2.0.0-rc.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06503bb33f294c5f1ba484011e053bfa6ae227074bdb841e9863492dc5960d4b"
dependencies = [
"hmac-sha256",
"lzma-rust2",
"ureq",
]
[[package]]
name = "owned_ttf_parser"
version = "0.25.1"
@ -2018,6 +2241,15 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec"
[[package]]
name = "pem-rfc7468"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412"
dependencies = [
"base64ct",
]
[[package]]
name = "percent-encoding"
version = "2.3.2"
@ -2078,6 +2310,21 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "portable-atomic"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950"
[[package]]
name = "portable-atomic-util"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507"
dependencies = [
"portable-atomic",
]
[[package]]
name = "potential_utf"
version = "0.1.4"
@ -2320,6 +2567,12 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9"
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.11.0"
@ -2373,7 +2626,10 @@ version = "0.1.0"
dependencies = [
"clap",
"image",
"ndarray",
"ort",
"reqwest",
"sha2",
"show-image",
]
@ -2470,6 +2726,19 @@ dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
dependencies = [
"bitflags 2.10.0",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.61.2",
]
[[package]]
name = "rustls"
version = "0.23.36"
@ -2490,10 +2759,10 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63"
dependencies = [
"openssl-probe",
"openssl-probe 0.2.1",
"rustls-pki-types",
"schannel",
"security-framework",
"security-framework 3.5.1",
]
[[package]]
@ -2521,7 +2790,7 @@ dependencies = [
"rustls-native-certs",
"rustls-platform-verifier-android",
"rustls-webpki",
"security-framework",
"security-framework 3.5.1",
"security-framework-sys",
"webpki-root-certs",
"windows-sys 0.61.2",
@ -2594,6 +2863,19 @@ dependencies = [
"tiny-skia",
]
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework"
version = "3.5.1"
@ -2653,6 +2935,17 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "sha2"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "shlex"
version = "1.3.0"
@ -2751,6 +3044,17 @@ dependencies = [
"windows-sys 0.60.2",
]
[[package]]
name = "socks"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
dependencies = [
"byteorder",
"libc",
"winapi",
]
[[package]]
name = "spirv"
version = "0.2.0+1.5.4"
@ -2854,6 +3158,19 @@ dependencies = [
"libc",
]
[[package]]
name = "tempfile"
version = "3.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c"
dependencies = [
"fastrand",
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.61.2",
]
[[package]]
name = "termcolor"
version = "1.4.1"
@ -3097,6 +3414,12 @@ version = "0.25.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2df906b07856748fa3f6e0ad0cbaa047052d4a7dd609e231c4f72cee8c36f31"
[[package]]
name = "typenum"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb"
[[package]]
name = "unicode-ident"
version = "1.0.22"
@ -3121,6 +3444,36 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "3.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a"
dependencies = [
"base64",
"der",
"log",
"native-tls",
"percent-encoding",
"rustls-pki-types",
"socks",
"ureq-proto",
"utf-8",
"webpki-root-certs",
]
[[package]]
name = "ureq-proto"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
dependencies = [
"base64",
"http",
"httparse",
"log",
]
[[package]]
name = "url"
version = "2.5.8"
@ -3133,6 +3486,12 @@ dependencies = [
"serde",
]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utf8_iter"
version = "1.0.4"
@ -3156,6 +3515,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "vec_map"
version = "0.8.2"

View file

@ -6,5 +6,8 @@
[dependencies]
clap = { version = "4.5", features = ["derive"] }
image = "0.25"
ndarray = "0.17"
ort = "=2.0.0-rc.11"
reqwest = { version = "0.13", features = ["blocking"] }
sha2 = "0.10"
show-image = { version = "0.14", features = ["image"] }

103
README.md Normal file
View file

@ -0,0 +1,103 @@
# Background Removal Tool
A Rust-based background removal tool using ONNX models (BiRefNet) with GPU support.
## Features
- Download images from URLs
- Automatic ONNX model download and caching
- Background removal using BiRefNet-general-lite model
- Side-by-side comparison display (original vs processed)
- Checkered background visualization for transparency
- CLI with flexible options
## Building
### With Nix (Recommended)
```bash
# Enter the development environment
nix develop
# Build the project
cargo build --release
```
### Without Nix
If you don't use Nix, you'll need to ensure you have the following system dependencies:
- Rust toolchain
- OpenSSL development libraries (`libssl-dev` on Ubuntu/Debian)
- pkg-config
- X11 libraries (for window display)
- Vulkan loader
## Usage
Basic usage:
```bash
# Download and process an image
cargo run -- https://example.com/image.jpg
# Use custom model
cargo run -- https://example.com/image.jpg --model-path ./custom_model.onnx
# Offline mode (use only cached model)
cargo run -- https://example.com/image.jpg --offline
```
## How It Works
1. **Download**: Fetches the image from the provided URL
2. **Model Loading**: Auto-downloads BiRefNet model (cached in `~/.cache/remove_background/models/`)
3. **Preprocessing**: Resizes image to 1024x1024, normalizes pixels, converts to CHW format
4. **Inference**: Runs ONNX model to generate foreground mask
5. **Postprocessing**: Applies mask to original image, creates RGBA output
6. **Display**: Shows side-by-side comparison with checkered background
## Project Structure
```
src/
├── main.rs - CLI interface and orchestration
├── model.rs - Model download and ONNX session management
├── preprocessing.rs - Image preprocessing (resize, normalize, CHW conversion)
└── postprocessing.rs - Mask application and side-by-side display creation
```
## Dependencies
- **ort**: ONNX Runtime Rust bindings
- **ndarray**: N-dimensional array operations
- **image**: Image loading and manipulation
- **reqwest**: HTTP client for downloads
- **show-image**: Cross-platform window display
- **clap**: Command-line argument parsing
- **sha2**: Hash verification
## Future Enhancements
- [ ] CUDA execution provider support (GPU acceleration)
- [ ] Batch processing support
- [ ] Additional model options (U2-Net, ISNet, etc.)
- [ ] Save output to file
- [ ] Local file input support
- [ ] Progress bars for downloads
- [ ] WebAssembly support
## Controls
- **ESC**: Close the window and exit
## Model Information
Default model: **BiRefNet-general-lite**
- Size: ~50MB
- Input: 1024x1024 RGB
- Output: 1024x1024 mask
- Source: https://huggingface.co/ZhengPeng7/BiRefNet
## License
See LICENSE file for details.

View file

@ -44,6 +44,7 @@
# Graphics/rendering libraries needed for wgpu (used by show-image)
graphicsBuildInputs = with pkgs; [
vulkan-loader
openssl
];
buildInputs = xorgBuildInputs ++ waylandBuildInputs ++ graphicsBuildInputs;

View file

@ -1,14 +1,29 @@
use clap::Parser;
use image::GenericImageView;
use ort::value::Tensor;
use show_image::{AsImageView, create_window, event};
use std::error::Error;
mod model;
mod postprocessing;
mod preprocessing;
use model::{ModelInfo, create_session, get_model_path};
use postprocessing::{apply_mask, create_side_by_side};
use preprocessing::preprocess_image;
#[derive(Parser)]
#[command(name = "remove_background")]
#[command(about = "Download and display an image (future: background removal)", long_about = None)]
#[command(about = "Remove background from images using ONNX models", long_about = None)]
struct Args {
#[arg(help = "URL of the image to download and display")]
#[arg(help = "URL of the image to download and process")]
url: String,
#[arg(long, help = "Path to custom ONNX model")]
model_path: Option<String>,
#[arg(long, help = "Skip model download, fail if not cached")]
offline: bool,
}
#[show_image::main]
@ -16,6 +31,7 @@ fn main() -> Result<(), Box<dyn Error>> {
// Parse command line arguments
let args = Args::parse();
println!("=== Background Removal Tool ===\n");
println!("Downloading image from: {}", args.url);
// Download the image with a user agent (using blocking client)
@ -34,17 +50,69 @@ fn main() -> Result<(), Box<dyn Error>> {
// Load the image
let img = image::load_from_memory(&bytes)?;
let (width, height) = img.dimensions();
println!("Loaded image: {}x{}", width, height);
println!("Loaded image: {}x{}\n", width, height);
// Create a window and display the image.
let window = create_window("image", Default::default())?;
// Get model info
let model_info = ModelInfo::BIREFNET_LITE;
println!("Using model: {}", model_info.name);
// Get or download model
let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?;
// Create ONNX Runtime session
let mut session = create_session(&model_path)?;
// Preprocess image
let input_tensor = preprocess_image(&img, model_info.input_size.0, model_info.input_size.1)?;
// Run inference
println!("Running inference...");
let outputs = session.run(ort::inputs!["input" => Tensor::from_array(input_tensor)?])?;
// Extract mask output
let mask_output = &outputs[0];
let (mask_shape, mask_array) = mask_output.try_extract_tensor::<f32>()?; // Returns (shape, &[f32])
/*let mask_tensor = mask_array.into_dimensionality::<ndarray::Ix4>()?.to_owned();
println!(
"Inference complete! Output shape: {:?}",
mask_tensor.shape()
);
// Apply mask to remove background
let result_rgba = apply_mask(&img, mask_tensor)?;
// Create side-by-side comparison
println!("Creating side-by-side comparison...");
let composite = create_side_by_side(&img, &result_rgba)?;
let composite_dynamic = image::DynamicImage::ImageRgba8(composite);
// Display the result
let (comp_width, comp_height) = composite_dynamic.dimensions();
let window = create_window(
"Background Removal - Original (Left) vs Result (Right) - Press ESC to close",
Default::default(),
)?;
window.set_image(
"downloaded image",
&img.as_image_view().map_err(|e| e.to_string())?,
"comparison",
&composite_dynamic
.as_image_view()
.map_err(|e| e.to_string())?,
)?;
println!("Image displayed. Press ESC to close the window.");
println!("\n=== Done! ===");
println!(
"Displaying side-by-side comparison ({}x{}):",
comp_width, comp_height
);
println!(" Left: Original image");
println!(" Right: Background removed (shown on checkered background)");
println!("\nPress ESC to close the window."); */
let window = create_window(
"Background Removal - Original (Left) vs Result (Right) - Press ESC to close",
Default::default(),
)?;
// Event loop - wait for ESC key to close
for event in window.event_channel()? {
if let event::WindowEvent::KeyboardInput(event) = event {

162
src/model.rs Normal file
View file

@ -0,0 +1,162 @@
use ort::session::{Session, builder::GraphOptimizationLevel};
use sha2::{Sha256, Digest};
use std::error::Error;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
/// Model metadata
pub struct ModelInfo {
pub name: &'static str,
pub url: &'static str,
pub sha256: Option<&'static str>,
pub input_size: (u32, u32),
}
impl ModelInfo {
/// BiRefNet General Lite model
pub const BIREFNET_LITE: ModelInfo = ModelInfo {
name: "birefnet-general-lite",
url: "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet-general-lite.onnx",
sha256: None, // We'll skip verification for now
input_size: (1024, 1024),
};
}
/// Get the cache directory for models
fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
let home = std::env::var("HOME")
.or_else(|_| std::env::var("USERPROFILE"))?;
let cache_dir = Path::new(&home).join(".cache").join("remove_background").join("models");
fs::create_dir_all(&cache_dir)?;
Ok(cache_dir)
}
/// Download a file from URL to destination
fn download_file(url: &str, dest: &Path) -> Result<(), Box<dyn Error>> {
println!("Downloading model from {}...", url);
let client = reqwest::blocking::Client::builder()
.user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36")
.build()?;
let mut response = client.get(url).send()?;
if !response.status().is_success() {
return Err(format!("Failed to download model: HTTP {}", response.status()).into());
}
let mut file = fs::File::create(dest)?;
let total_size = response.content_length().unwrap_or(0);
let mut downloaded = 0u64;
let mut buffer = vec![0; 8192];
loop {
let bytes_read = std::io::Read::read(&mut response, &mut buffer)?;
if bytes_read == 0 {
break;
}
file.write_all(&buffer[..bytes_read])?;
downloaded += bytes_read as u64;
if total_size > 0 {
let progress = (downloaded as f64 / total_size as f64 * 100.0) as u32;
print!("\rDownloading... {}%", progress);
std::io::stdout().flush()?;
}
}
if total_size > 0 {
println!("\rDownload complete! ");
} else {
println!("Download complete! ({} bytes)", downloaded);
}
Ok(())
}
/// Verify file SHA256 hash
fn verify_hash(file_path: &Path, expected_hash: &str) -> Result<bool, Box<dyn Error>> {
let mut file = fs::File::open(file_path)?;
let mut hasher = Sha256::new();
std::io::copy(&mut file, &mut hasher)?;
let hash = hasher.finalize();
let hash_str = format!("{:x}", hash);
Ok(hash_str == expected_hash)
}
/// Get or download model, return path to model file
pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline: bool) -> Result<PathBuf, Box<dyn Error>> {
// If custom path provided, use it
if let Some(path) = custom_path {
let model_path = PathBuf::from(path);
if !model_path.exists() {
return Err(format!("Custom model path does not exist: {}", path).into());
}
println!("Using custom model: {}", path);
return Ok(model_path);
}
// Check cache
let cache_dir = get_cache_dir()?;
let model_filename = format!("{}.onnx", model_info.name);
let model_path = cache_dir.join(&model_filename);
if model_path.exists() {
println!("Using cached model: {}", model_path.display());
// Verify hash if provided
if let Some(expected_hash) = model_info.sha256 {
print!("Verifying model integrity... ");
std::io::stdout().flush()?;
if verify_hash(&model_path, expected_hash)? {
println!("OK");
} else {
println!("FAILED");
return Err("Model hash verification failed. Try deleting the cached model.".into());
}
}
return Ok(model_path);
}
// Download if not in offline mode
if offline {
return Err(format!(
"Model not found in cache and offline mode is enabled. Cache path: {}",
model_path.display()
).into());
}
println!("Model not found in cache, downloading...");
download_file(model_info.url, &model_path)?;
// Verify after download
if let Some(expected_hash) = model_info.sha256 {
print!("Verifying downloaded model... ");
std::io::stdout().flush()?;
if verify_hash(&model_path, expected_hash)? {
println!("OK");
} else {
// Delete corrupted file
fs::remove_file(&model_path)?;
return Err("Downloaded model hash verification failed".into());
}
}
Ok(model_path)
}
/// Create an ONNX Runtime session from model path
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
println!("Loading model into ONNX Runtime...");
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)?
.commit_from_file(model_path)?;
println!("Model loaded successfully!");
Ok(session)
}

149
src/postprocessing.rs Normal file
View file

@ -0,0 +1,149 @@
use image::{DynamicImage, RgbaImage, Rgba, GenericImageView, imageops::FilterType};
use ndarray::Array4;
use std::error::Error;
/// Apply mask to original image to remove background
///
/// Steps:
/// 1. Extract mask from output tensor (shape: [1, 1, H, W])
/// 2. Resize mask to match original image dimensions
/// 3. Apply mask as alpha channel to create RGBA image
pub fn apply_mask(
original: &DynamicImage,
mask_tensor: Array4<f32>,
) -> Result<RgbaImage, Box<dyn Error>> {
println!("Applying mask to remove background...");
let (orig_width, orig_height) = original.dimensions();
// Extract mask dimensions
let mask_shape = mask_tensor.shape();
if mask_shape[0] != 1 || mask_shape[1] != 1 {
return Err(format!(
"Expected mask shape [1, 1, H, W], got {:?}",
mask_shape
).into());
}
let mask_height = mask_shape[2] as u32;
let mask_width = mask_shape[3] as u32;
println!("Mask dimensions: {}x{}", mask_width, mask_height);
println!("Original dimensions: {}x{}", orig_width, orig_height);
// Create a grayscale image from the mask
let mut mask_image = image::GrayImage::new(mask_width, mask_height);
for y in 0..mask_height {
for x in 0..mask_width {
let mask_value = mask_tensor[[0, 0, y as usize, x as usize]];
// Clamp and convert to u8
let pixel_value = (mask_value.clamp(0.0, 1.0) * 255.0) as u8;
mask_image.put_pixel(x, y, image::Luma([pixel_value]));
}
}
// Resize mask to match original image dimensions if needed
let resized_mask = if mask_width != orig_width || mask_height != orig_height {
println!("Resizing mask to match original image...");
image::imageops::resize(
&mask_image,
orig_width,
orig_height,
FilterType::Lanczos3,
)
} else {
mask_image
};
// Convert original to RGBA and apply mask
let rgba_original = original.to_rgba8();
let mut result = RgbaImage::new(orig_width, orig_height);
for y in 0..orig_height {
for x in 0..orig_width {
let orig_pixel = rgba_original.get_pixel(x, y);
let mask_pixel = resized_mask.get_pixel(x, y);
// Apply mask as alpha channel
result.put_pixel(
x,
y,
Rgba([
orig_pixel[0], // R
orig_pixel[1], // G
orig_pixel[2], // B
mask_pixel[0], // Alpha from mask
]),
);
}
}
println!("Background removal complete!");
Ok(result)
}
/// Create a side-by-side comparison image
pub fn create_side_by_side(
original: &DynamicImage,
result: &RgbaImage,
) -> Result<RgbaImage, Box<dyn Error>> {
let (width, height) = original.dimensions();
let mut composite = RgbaImage::new(width * 2, height);
// Left side: original image
let original_rgba = original.to_rgba8();
for y in 0..height {
for x in 0..width {
composite.put_pixel(x, y, *original_rgba.get_pixel(x, y));
}
}
// Right side: result with checkered background for transparency
for y in 0..height {
for x in 0..width {
let result_pixel = result.get_pixel(x, y);
let alpha = result_pixel[3] as f32 / 255.0;
// Create checkered background (8x8 squares)
let checker_size = 8;
let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0;
let bg_color = if is_light { 200 } else { 150 };
// Alpha blend with checkered background
let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8;
composite.put_pixel(
x + width,
y,
Rgba([final_r, final_g, final_b, 255]),
);
}
}
Ok(composite)
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array4;
#[test]
fn test_apply_mask_shape() {
let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
let mask = Array4::<f32>::ones((1, 1, 100, 100));
let result = apply_mask(&img, mask).unwrap();
assert_eq!(result.dimensions(), (100, 100));
}
#[test]
fn test_side_by_side_dimensions() {
let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100));
let result_img = RgbaImage::new(100, 100);
let composite = create_side_by_side(&img, &result_img).unwrap();
assert_eq!(composite.dimensions(), (200, 100)); // Double width
}
}

67
src/preprocessing.rs Normal file
View file

@ -0,0 +1,67 @@
use image::{DynamicImage, imageops::FilterType};
use ndarray::Array4;
use std::error::Error;
/// Preprocess an image for the BiRefNet model
///
/// Steps:
/// 1. Resize to target dimensions (1024x1024)
/// 2. Convert from u8 [0, 255] to f32 [0.0, 1.0]
/// 3. Rearrange from HWC (Height, Width, Channels) to CHW format
/// 4. Add batch dimension: [1, 3, H, W]
pub fn preprocess_image(
img: &DynamicImage,
target_width: u32,
target_height: u32,
) -> Result<Array4<f32>, Box<dyn Error>> {
println!("Preprocessing image...");
// Step 1: Resize image
let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3);
let rgb_image = resized.to_rgb8();
let (width, height) = rgb_image.dimensions();
// Step 2: Create ndarray with shape [1, 3, height, width]
let mut array = Array4::<f32>::zeros((1, 3, height as usize, width as usize));
// Step 3: Fill the array, converting from HWC to CHW and normalizing
for y in 0..height {
for x in 0..width {
let pixel = rgb_image.get_pixel(x, y);
// Normalize from [0, 255] to [0.0, 1.0]
array[[0, 0, y as usize, x as usize]] = pixel[0] as f32 / 255.0; // R
array[[0, 1, y as usize, x as usize]] = pixel[1] as f32 / 255.0; // G
array[[0, 2, y as usize, x as usize]] = pixel[2] as f32 / 255.0; // B
}
}
println!("Preprocessing complete. Tensor shape: {:?}", array.shape());
Ok(array)
}
#[cfg(test)]
mod tests {
use super::*;
use image::RgbImage;
#[test]
fn test_preprocess_shape() {
let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
let result = preprocess_image(&img, 1024, 1024).unwrap();
assert_eq!(result.shape(), &[1, 3, 1024, 1024]);
}
#[test]
fn test_preprocess_normalization() {
let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
let result = preprocess_image(&img, 1024, 1024).unwrap();
// Check that all values are in [0.0, 1.0]
for &val in result.iter() {
assert!(val >= 0.0 && val <= 1.0);
}
}
}