From 296a0172f18b975bbb0c0a2d30ab5c1cb37cd9fa Mon Sep 17 00:00:00 2001 From: Matthew Deville Date: Thu, 22 Jan 2026 00:25:11 +0100 Subject: [PATCH] compiles so far --- Cargo.lock | 371 +++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 3 + README.md | 103 ++++++++++++ flake.nix | 1 + src/main.rs | 86 +++++++++- src/model.rs | 162 ++++++++++++++++++ src/postprocessing.rs | 149 +++++++++++++++++ src/preprocessing.rs | 67 ++++++++ 8 files changed, 930 insertions(+), 12 deletions(-) create mode 100644 README.md create mode 100644 src/model.rs create mode 100644 src/postprocessing.rs create mode 100644 src/preprocessing.rs diff --git a/Cargo.lock b/Cargo.lock index 66ecc55..ad4810e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -303,6 +303,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + [[package]] name = "bit-set" version = "0.5.3" @@ -351,6 +357,15 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "block-sys" version = "0.1.0-beta.1" @@ -388,6 +403,12 @@ version = "1.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "byteorder-lite" version = "0.1.0" @@ -596,6 +617,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -636,6 +666,16 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "d3d12" version = "0.7.0" @@ -647,6 +687,26 @@ dependencies = [ "winapi", ] +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "dispatch" version = "0.2.0" @@ -726,6 +786,16 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "exr" 
version = "1.74.0" @@ -741,6 +811,12 @@ dependencies = [ "zune-inflate", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "fax" version = "0.2.6" @@ -932,6 +1008,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.17" @@ -1124,6 +1210,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" +[[package]] +name = "hmac-sha256" +version = "1.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad6880c8d4a9ebf39c6e8b77007ce223f646a4d21ce29d99f70cb16420545425" + [[package]] name = "http" version = "1.4.0" @@ -1561,6 +1653,12 @@ dependencies = [ "redox_syscall 0.7.0", ] +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "litemap" version = "0.8.1" @@ -1597,6 +1695,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "lzma-rust2" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1670343e58806300d87950e3401e820b519b9384281bbabfb15e3636689ffd69" + [[package]] name = "malloc_buf" version = "0.0.6" @@ -1606,6 +1710,16 @@ dependencies = [ "libc", ] +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -1725,6 +1839,38 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe 0.1.6", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk" version = "0.7.0" @@ -1810,6 +1956,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-derive" version = "0.4.2" @@ -1958,12 +2113,56 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags 2.10.0", + "cfg-if", + "foreign-types 0.3.2", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + [[package]] name = "openssl-probe" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "orbclient" version = "0.3.50" @@ -1974,6 +2173,30 @@ dependencies = [ "libredox", ] +[[package]] +name = "ort" +version = "2.0.0-rc.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5df903c0d2c07b56950f1058104ab0c8557159f2741782223704de9be73c3c" +dependencies = [ + "ndarray", + "ort-sys", + "smallvec", + "tracing", + "ureq", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06503bb33f294c5f1ba484011e053bfa6ae227074bdb841e9863492dc5960d4b" +dependencies = [ + "hmac-sha256", + "lzma-rust2", + "ureq", +] + [[package]] name = "owned_ttf_parser" version = "0.25.1" @@ -2018,6 +2241,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + 
"base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -2078,6 +2310,21 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "portable-atomic" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.4" @@ -2320,6 +2567,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.11.0" @@ -2373,7 +2626,10 @@ version = "0.1.0" dependencies = [ "clap", "image", + "ndarray", + "ort", "reqwest", + "sha2", "show-image", ] @@ -2470,6 +2726,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags 2.10.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustls" version = "0.23.36" @@ -2490,10 +2759,10 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" dependencies = [ - "openssl-probe", + "openssl-probe 0.2.1", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.5.1", ] [[package]] @@ -2521,7 +2790,7 @@ dependencies = [ 
"rustls-native-certs", "rustls-platform-verifier-android", "rustls-webpki", - "security-framework", + "security-framework 3.5.1", "security-framework-sys", "webpki-root-certs", "windows-sys 0.61.2", @@ -2594,6 +2863,19 @@ dependencies = [ "tiny-skia", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.10.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + [[package]] name = "security-framework" version = "3.5.1" @@ -2653,6 +2935,17 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -2751,6 +3044,17 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "spirv" version = "0.2.0+1.5.4" @@ -2854,6 +3158,19 @@ dependencies = [ "libc", ] +[[package]] +name = "tempfile" +version = "3.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "termcolor" version = "1.4.1" @@ -3097,6 +3414,12 @@ version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2df906b07856748fa3f6e0ad0cbaa047052d4a7dd609e231c4f72cee8c36f31" +[[package]] +name = 
"typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicode-ident" version = "1.0.22" @@ -3121,6 +3444,36 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "3.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a" +dependencies = [ + "base64", + "der", + "log", + "native-tls", + "percent-encoding", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf-8", + "webpki-root-certs", +] + +[[package]] +name = "ureq-proto" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" +dependencies = [ + "base64", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.8" @@ -3133,6 +3486,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -3156,6 +3515,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "vec_map" version = "0.8.2" diff --git a/Cargo.toml b/Cargo.toml index 0207bf8..0ffd4ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,5 +6,8 @@ [dependencies] clap = { version = "4.5", features = ["derive"] } image = "0.25" + ndarray = "0.17" + ort = "=2.0.0-rc.11" reqwest = { version = "0.13", features = ["blocking"] } + sha2 = "0.10" 
show-image = { version = "0.14", features = ["image"] } diff --git a/README.md b/README.md new file mode 100644 index 0000000..20ff475 --- /dev/null +++ b/README.md @@ -0,0 +1,103 @@ +# Background Removal Tool + +A Rust-based background removal tool using ONNX models (BiRefNet); GPU acceleration is planned but not yet implemented. + +## Features + +- Download images from URLs +- Automatic ONNX model download and caching +- Background removal using BiRefNet-general-lite model +- Side-by-side comparison display (original vs processed) +- Checkered background visualization for transparency +- CLI with flexible options + +## Building + +### With Nix (Recommended) + +```bash +# Enter the development environment +nix develop + +# Build the project +cargo build --release +``` + +### Without Nix + +If you don't use Nix, you'll need to ensure you have the following system dependencies: +- Rust toolchain +- OpenSSL development libraries (`libssl-dev` on Ubuntu/Debian) +- pkg-config +- X11 libraries (for window display) +- Vulkan loader + +## Usage + +Basic usage: + +```bash +# Download and process an image +cargo run -- https://example.com/image.jpg + +# Use custom model +cargo run -- https://example.com/image.jpg --model-path ./custom_model.onnx + +# Offline mode (use only cached model) +cargo run -- https://example.com/image.jpg --offline +``` + +## How It Works + +1. **Download**: Fetches the image from the provided URL +2. **Model Loading**: Auto-downloads BiRefNet model (cached in `~/.cache/remove_background/models/`) +3. **Preprocessing**: Resizes image to 1024x1024, normalizes pixels, converts to CHW format +4. **Inference**: Runs ONNX model to generate foreground mask +5. **Postprocessing**: Applies mask to original image, creates RGBA output +6. 
**Display**: Shows side-by-side comparison with checkered background + +## Project Structure + +``` +src/ +├── main.rs - CLI interface and orchestration +├── model.rs - Model download and ONNX session management +├── preprocessing.rs - Image preprocessing (resize, normalize, CHW conversion) +└── postprocessing.rs - Mask application and side-by-side display creation +``` + +## Dependencies + +- **ort**: ONNX Runtime Rust bindings +- **ndarray**: N-dimensional array operations +- **image**: Image loading and manipulation +- **reqwest**: HTTP client for downloads +- **show-image**: Cross-platform window display +- **clap**: Command-line argument parsing +- **sha2**: Hash verification + +## Future Enhancements + +- [ ] CUDA execution provider support (GPU acceleration) +- [ ] Batch processing support +- [ ] Additional model options (U2-Net, ISNet, etc.) +- [ ] Save output to file +- [ ] Local file input support +- [ ] Progress bars for downloads +- [ ] WebAssembly support + +## Controls + +- **ESC**: Close the window and exit + +## Model Information + +Default model: **BiRefNet-general-lite** +- Size: ~50MB +- Input: 1024x1024 RGB +- Output: 1024x1024 mask +- Source: https://huggingface.co/ZhengPeng7/BiRefNet + +## License + +See LICENSE file for details. 
diff --git a/flake.nix b/flake.nix index ec8b92e..a7257a5 100644 --- a/flake.nix +++ b/flake.nix @@ -44,6 +44,7 @@ # Graphics/rendering libraries needed for wgpu (used by show-image) graphicsBuildInputs = with pkgs; [ vulkan-loader + openssl ]; buildInputs = xorgBuildInputs ++ waylandBuildInputs ++ graphicsBuildInputs; diff --git a/src/main.rs b/src/main.rs index 1c73b73..0530f46 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,29 @@ use clap::Parser; use image::GenericImageView; +use ort::value::Tensor; use show_image::{AsImageView, create_window, event}; use std::error::Error; +mod model; +mod postprocessing; +mod preprocessing; + +use model::{ModelInfo, create_session, get_model_path}; +use postprocessing::{apply_mask, create_side_by_side}; +use preprocessing::preprocess_image; + #[derive(Parser)] #[command(name = "remove_background")] -#[command(about = "Download and display an image (future: background removal)", long_about = None)] +#[command(about = "Remove background from images using ONNX models", long_about = None)] struct Args { - #[arg(help = "URL of the image to download and display")] + #[arg(help = "URL of the image to download and process")] url: String, + + #[arg(long, help = "Path to custom ONNX model")] + model_path: Option, + + #[arg(long, help = "Skip model download, fail if not cached")] + offline: bool, } #[show_image::main] @@ -16,6 +31,7 @@ fn main() -> Result<(), Box> { // Parse command line arguments let args = Args::parse(); + println!("=== Background Removal Tool ===\n"); println!("Downloading image from: {}", args.url); // Download the image with a user agent (using blocking client) @@ -34,17 +50,69 @@ fn main() -> Result<(), Box> { // Load the image let img = image::load_from_memory(&bytes)?; let (width, height) = img.dimensions(); - println!("Loaded image: {}x{}", width, height); + println!("Loaded image: {}x{}\n", width, height); - // Create a window and display the image. 
- let window = create_window("image", Default::default())?; + // Get model info + let model_info = ModelInfo::BIREFNET_LITE; + println!("Using model: {}", model_info.name); + + // Get or download model + let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?; + + // Create ONNX Runtime session + let mut session = create_session(&model_path)?; + + // Preprocess image + let input_tensor = preprocess_image(&img, model_info.input_size.0, model_info.input_size.1)?; + + // Run inference + println!("Running inference..."); + let outputs = session.run(ort::inputs!["input" => Tensor::from_array(input_tensor)?])?; + + // Extract mask output + let mask_output = &outputs[0]; + let (mask_shape, mask_array) = mask_output.try_extract_tensor::()?; // Returns (shape, &[f32]) + + /*let mask_tensor = mask_array.into_dimensionality::()?.to_owned(); + + println!( + "Inference complete! Output shape: {:?}", + mask_tensor.shape() + ); + + // Apply mask to remove background + let result_rgba = apply_mask(&img, mask_tensor)?; + + // Create side-by-side comparison + println!("Creating side-by-side comparison..."); + let composite = create_side_by_side(&img, &result_rgba)?; + let composite_dynamic = image::DynamicImage::ImageRgba8(composite); + + // Display the result + let (comp_width, comp_height) = composite_dynamic.dimensions(); + let window = create_window( + "Background Removal - Original (Left) vs Result (Right) - Press ESC to close", + Default::default(), + )?; window.set_image( - "downloaded image", - &img.as_image_view().map_err(|e| e.to_string())?, + "comparison", + &composite_dynamic + .as_image_view() + .map_err(|e| e.to_string())?, )?; - println!("Image displayed. Press ESC to close the window."); - + println!("\n=== Done! 
==="); + println!( + "Displaying side-by-side comparison ({}x{}):", + comp_width, comp_height + ); + println!(" Left: Original image"); + println!(" Right: Background removed (shown on checkered background)"); + println!("\nPress ESC to close the window."); */ + let window = create_window( + "Background Removal - Original (Left) vs Result (Right) - Press ESC to close", + Default::default(), + )?; // Event loop - wait for ESC key to close for event in window.event_channel()? { if let event::WindowEvent::KeyboardInput(event) = event { diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..a922313 --- /dev/null +++ b/src/model.rs @@ -0,0 +1,162 @@ +use ort::session::{Session, builder::GraphOptimizationLevel}; +use sha2::{Sha256, Digest}; +use std::error::Error; +use std::fs; +use std::io::Write; +use std::path::{Path, PathBuf}; + +/// Model metadata +pub struct ModelInfo { + pub name: &'static str, + pub url: &'static str, + pub sha256: Option<&'static str>, + pub input_size: (u32, u32), +} + +impl ModelInfo { + /// BiRefNet General Lite model + pub const BIREFNET_LITE: ModelInfo = ModelInfo { + name: "birefnet-general-lite", + url: "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet-general-lite.onnx", + sha256: None, // We'll skip verification for now + input_size: (1024, 1024), + }; +} + +/// Get the cache directory for models +fn get_cache_dir() -> Result> { + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE"))?; + let cache_dir = Path::new(&home).join(".cache").join("remove_background").join("models"); + fs::create_dir_all(&cache_dir)?; + Ok(cache_dir) +} + +/// Download a file from URL to destination +fn download_file(url: &str, dest: &Path) -> Result<(), Box> { + println!("Downloading model from {}...", url); + + let client = reqwest::blocking::Client::builder() + .user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36") + .build()?; + + let mut response = client.get(url).send()?; + + if 
!response.status().is_success() { + return Err(format!("Failed to download model: HTTP {}", response.status()).into()); + } + + let mut file = fs::File::create(dest)?; + let total_size = response.content_length().unwrap_or(0); + let mut downloaded = 0u64; + + let mut buffer = vec![0; 8192]; + loop { + let bytes_read = std::io::Read::read(&mut response, &mut buffer)?; + if bytes_read == 0 { + break; + } + file.write_all(&buffer[..bytes_read])?; + downloaded += bytes_read as u64; + + if total_size > 0 { + let progress = (downloaded as f64 / total_size as f64 * 100.0) as u32; + print!("\rDownloading... {}%", progress); + std::io::stdout().flush()?; + } + } + + if total_size > 0 { + println!("\rDownload complete! "); + } else { + println!("Download complete! ({} bytes)", downloaded); + } + + Ok(()) +} + +/// Verify file SHA256 hash +fn verify_hash(file_path: &Path, expected_hash: &str) -> Result> { + let mut file = fs::File::open(file_path)?; + let mut hasher = Sha256::new(); + std::io::copy(&mut file, &mut hasher)?; + let hash = hasher.finalize(); + let hash_str = format!("{:x}", hash); + Ok(hash_str == expected_hash) +} + +/// Get or download model, return path to model file +pub fn get_model_path(model_info: &ModelInfo, custom_path: Option<&str>, offline: bool) -> Result> { + // If custom path provided, use it + if let Some(path) = custom_path { + let model_path = PathBuf::from(path); + if !model_path.exists() { + return Err(format!("Custom model path does not exist: {}", path).into()); + } + println!("Using custom model: {}", path); + return Ok(model_path); + } + + // Check cache + let cache_dir = get_cache_dir()?; + let model_filename = format!("{}.onnx", model_info.name); + let model_path = cache_dir.join(&model_filename); + + if model_path.exists() { + println!("Using cached model: {}", model_path.display()); + + // Verify hash if provided + if let Some(expected_hash) = model_info.sha256 { + print!("Verifying model integrity... 
"); + std::io::stdout().flush()?; + if verify_hash(&model_path, expected_hash)? { + println!("OK"); + } else { + println!("FAILED"); + return Err("Model hash verification failed. Try deleting the cached model.".into()); + } + } + + return Ok(model_path); + } + + // Download if not in offline mode + if offline { + return Err(format!( + "Model not found in cache and offline mode is enabled. Cache path: {}", + model_path.display() + ).into()); + } + + println!("Model not found in cache, downloading..."); + download_file(model_info.url, &model_path)?; + + // Verify after download + if let Some(expected_hash) = model_info.sha256 { + print!("Verifying downloaded model... "); + std::io::stdout().flush()?; + if verify_hash(&model_path, expected_hash)? { + println!("OK"); + } else { + // Delete corrupted file + fs::remove_file(&model_path)?; + return Err("Downloaded model hash verification failed".into()); + } + } + + Ok(model_path) +} + +/// Create an ONNX Runtime session from model path +pub fn create_session(model_path: &Path) -> Result> { + println!("Loading model into ONNX Runtime..."); + + let session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_intra_threads(4)? + .commit_from_file(model_path)?; + + println!("Model loaded successfully!"); + + Ok(session) +} diff --git a/src/postprocessing.rs b/src/postprocessing.rs new file mode 100644 index 0000000..ed7b11e --- /dev/null +++ b/src/postprocessing.rs @@ -0,0 +1,149 @@ +use image::{DynamicImage, RgbaImage, Rgba, GenericImageView, imageops::FilterType}; +use ndarray::Array4; +use std::error::Error; + +/// Apply mask to original image to remove background +/// +/// Steps: +/// 1. Extract mask from output tensor (shape: [1, 1, H, W]) +/// 2. Resize mask to match original image dimensions +/// 3. 
Apply mask as alpha channel to create RGBA image +pub fn apply_mask( + original: &DynamicImage, + mask_tensor: Array4, +) -> Result> { + println!("Applying mask to remove background..."); + + let (orig_width, orig_height) = original.dimensions(); + + // Extract mask dimensions + let mask_shape = mask_tensor.shape(); + if mask_shape[0] != 1 || mask_shape[1] != 1 { + return Err(format!( + "Expected mask shape [1, 1, H, W], got {:?}", + mask_shape + ).into()); + } + + let mask_height = mask_shape[2] as u32; + let mask_width = mask_shape[3] as u32; + + println!("Mask dimensions: {}x{}", mask_width, mask_height); + println!("Original dimensions: {}x{}", orig_width, orig_height); + + // Create a grayscale image from the mask + let mut mask_image = image::GrayImage::new(mask_width, mask_height); + for y in 0..mask_height { + for x in 0..mask_width { + let mask_value = mask_tensor[[0, 0, y as usize, x as usize]]; + // Clamp and convert to u8 + let pixel_value = (mask_value.clamp(0.0, 1.0) * 255.0) as u8; + mask_image.put_pixel(x, y, image::Luma([pixel_value])); + } + } + + // Resize mask to match original image dimensions if needed + let resized_mask = if mask_width != orig_width || mask_height != orig_height { + println!("Resizing mask to match original image..."); + image::imageops::resize( + &mask_image, + orig_width, + orig_height, + FilterType::Lanczos3, + ) + } else { + mask_image + }; + + // Convert original to RGBA and apply mask + let rgba_original = original.to_rgba8(); + let mut result = RgbaImage::new(orig_width, orig_height); + + for y in 0..orig_height { + for x in 0..orig_width { + let orig_pixel = rgba_original.get_pixel(x, y); + let mask_pixel = resized_mask.get_pixel(x, y); + + // Apply mask as alpha channel + result.put_pixel( + x, + y, + Rgba([ + orig_pixel[0], // R + orig_pixel[1], // G + orig_pixel[2], // B + mask_pixel[0], // Alpha from mask + ]), + ); + } + } + + println!("Background removal complete!"); + + Ok(result) +} + +/// Create a 
side-by-side comparison image +pub fn create_side_by_side( + original: &DynamicImage, + result: &RgbaImage, +) -> Result> { + let (width, height) = original.dimensions(); + let mut composite = RgbaImage::new(width * 2, height); + + // Left side: original image + let original_rgba = original.to_rgba8(); + for y in 0..height { + for x in 0..width { + composite.put_pixel(x, y, *original_rgba.get_pixel(x, y)); + } + } + + // Right side: result with checkered background for transparency + for y in 0..height { + for x in 0..width { + let result_pixel = result.get_pixel(x, y); + let alpha = result_pixel[3] as f32 / 255.0; + + // Create checkered background (8x8 squares) + let checker_size = 8; + let is_light = ((x / checker_size) + (y / checker_size)) % 2 == 0; + let bg_color = if is_light { 200 } else { 150 }; + + // Alpha blend with checkered background + let final_r = (result_pixel[0] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; + let final_g = (result_pixel[1] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; + let final_b = (result_pixel[2] as f32 * alpha + bg_color as f32 * (1.0 - alpha)) as u8; + + composite.put_pixel( + x + width, + y, + Rgba([final_r, final_g, final_b, 255]), + ); + } + } + + Ok(composite) +} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::Array4; + + #[test] + fn test_apply_mask_shape() { + let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100)); + let mask = Array4::::ones((1, 1, 100, 100)); + let result = apply_mask(&img, mask).unwrap(); + assert_eq!(result.dimensions(), (100, 100)); + } + + #[test] + fn test_side_by_side_dimensions() { + let img = DynamicImage::ImageRgb8(image::RgbImage::new(100, 100)); + let result_img = RgbaImage::new(100, 100); + let composite = create_side_by_side(&img, &result_img).unwrap(); + assert_eq!(composite.dimensions(), (200, 100)); // Double width + } +} diff --git a/src/preprocessing.rs b/src/preprocessing.rs new file mode 100644 index 0000000..95d514c --- /dev/null +++ 
b/src/preprocessing.rs @@ -0,0 +1,67 @@ +use image::{DynamicImage, imageops::FilterType}; +use ndarray::Array4; +use std::error::Error; + +/// Preprocess an image for the BiRefNet model +/// +/// Steps: +/// 1. Resize to target dimensions (1024x1024) +/// 2. Convert from u8 [0, 255] to f32 [0.0, 1.0] +/// 3. Rearrange from HWC (Height, Width, Channels) to CHW format +/// 4. Add batch dimension: [1, 3, H, W] +pub fn preprocess_image( + img: &DynamicImage, + target_width: u32, + target_height: u32, +) -> Result, Box> { + println!("Preprocessing image..."); + + // Step 1: Resize image + let resized = img.resize_exact(target_width, target_height, FilterType::Lanczos3); + let rgb_image = resized.to_rgb8(); + + let (width, height) = rgb_image.dimensions(); + + // Step 2: Create ndarray with shape [1, 3, height, width] + let mut array = Array4::::zeros((1, 3, height as usize, width as usize)); + + // Step 3: Fill the array, converting from HWC to CHW and normalizing + for y in 0..height { + for x in 0..width { + let pixel = rgb_image.get_pixel(x, y); + + // Normalize from [0, 255] to [0.0, 1.0] + array[[0, 0, y as usize, x as usize]] = pixel[0] as f32 / 255.0; // R + array[[0, 1, y as usize, x as usize]] = pixel[1] as f32 / 255.0; // G + array[[0, 2, y as usize, x as usize]] = pixel[2] as f32 / 255.0; // B + } + } + + println!("Preprocessing complete. 
Tensor shape: {:?}", array.shape()); + + Ok(array) +} + +#[cfg(test)] +mod tests { + use super::*; + use image::RgbImage; + + #[test] + fn test_preprocess_shape() { + let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100)); + let result = preprocess_image(&img, 1024, 1024).unwrap(); + assert_eq!(result.shape(), &[1, 3, 1024, 1024]); + } + + #[test] + fn test_preprocess_normalization() { + let img = DynamicImage::ImageRgb8(RgbImage::new(100, 100)); + let result = preprocess_image(&img, 1024, 1024).unwrap(); + + // Check that all values are in [0.0, 1.0] + for &val in result.iter() { + assert!(val >= 0.0 && val <= 1.0); + } + } +}