Compare commits
10 commits
86e0dd734d
...
b979fadb43
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b979fadb43 | ||
|
|
1a338de3c0 | ||
|
|
596cd28dcc | ||
|
|
6e965c4a18 | ||
|
|
609afde727 | ||
|
|
22c9315b96 | ||
|
|
9e9721cbbc | ||
|
|
f6961c3177 | ||
|
|
1fae113c5c | ||
|
|
75dd44f27a |
8 changed files with 83 additions and 29 deletions
1
.envrc
Normal file
1
.envrc
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
use flake
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -1 +1,2 @@
|
||||||
/target
|
/target
|
||||||
|
.direnv
|
||||||
|
|
|
||||||
9
.vscode/extensions.json
vendored
Normal file
9
.vscode/extensions.json
vendored
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
{
|
||||||
|
"recommendations": [
|
||||||
|
"fill-labs.dependi",
|
||||||
|
"jnoortheen.nix-ide",
|
||||||
|
"mkhl.direnv",
|
||||||
|
"rust-lang.rust-analyzer",
|
||||||
|
"tamasfe.even-better-toml",
|
||||||
|
]
|
||||||
|
}
|
||||||
12
flake.lock
12
flake.lock
|
|
@ -20,11 +20,11 @@
|
||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1768886240,
|
"lastModified": 1769789167,
|
||||||
"narHash": "sha256-C2TjvwYZ2VDxYWeqvvJ5XPPp6U7H66zeJlRaErJKoEM=",
|
"narHash": "sha256-kKB3bqYJU5nzYeIROI82Ef9VtTbu4uA3YydSk/Bioa8=",
|
||||||
"owner": "NixOS",
|
"owner": "NixOS",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "80e4adbcf8992d3fd27ad4964fbb84907f9478b0",
|
"rev": "62c8382960464ceb98ea593cb8321a2cf8f9e3e5",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
@ -62,11 +62,11 @@
|
||||||
"nixpkgs": "nixpkgs_2"
|
"nixpkgs": "nixpkgs_2"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1768963622,
|
"lastModified": 1769915446,
|
||||||
"narHash": "sha256-n6VHiUgrYD9yjagzG6ncVVqFbVTsKCI54tR9PNAFCo0=",
|
"narHash": "sha256-f1F/umtX3ZD7fF9DHSloVHc0mnAT0ry0YK2jI/6E0aI=",
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "2ef5b3362af585a83bafd34e7fc9b1f388c2e5e2",
|
"rev": "bc00300f010275e46feb3c3974df6587ff7b7808",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,9 @@
|
||||||
overlays = [ (import rust-overlay) ];
|
overlays = [ (import rust-overlay) ];
|
||||||
pkgs = import nixpkgs { inherit system overlays; };
|
pkgs = import nixpkgs { inherit system overlays; };
|
||||||
stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv;
|
stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv;
|
||||||
rustToolchain = pkgs.rust-bin.stable.latest.default;
|
rustToolchain = pkgs.rust-bin.stable.latest.default.override {
|
||||||
|
extensions = [ "rust-src" ];
|
||||||
|
};
|
||||||
|
|
||||||
nativeBuildInputs = with pkgs; [
|
nativeBuildInputs = with pkgs; [
|
||||||
rustToolchain
|
rustToolchain
|
||||||
|
|
|
||||||
|
|
@ -1,2 +0,0 @@
|
||||||
[toolchain]
|
|
||||||
channel = "stable"
|
|
||||||
44
src/main.rs
44
src/main.rs
|
|
@ -9,7 +9,7 @@ mod model;
|
||||||
mod postprocessing;
|
mod postprocessing;
|
||||||
mod preprocessing;
|
mod preprocessing;
|
||||||
|
|
||||||
use model::{ModelInfo, create_session, get_model_path};
|
use model::{Model, create_session, get_model_path};
|
||||||
use postprocessing::{apply_mask, create_side_by_side};
|
use postprocessing::{apply_mask, create_side_by_side};
|
||||||
use preprocessing::preprocess_image;
|
use preprocessing::preprocess_image;
|
||||||
|
|
||||||
|
|
@ -20,7 +20,17 @@ struct Args {
|
||||||
#[arg(help = "URL of the image to download and process")]
|
#[arg(help = "URL of the image to download and process")]
|
||||||
url: String,
|
url: String,
|
||||||
|
|
||||||
#[arg(long, help = "Path to custom ONNX model")]
|
#[arg(
|
||||||
|
short,
|
||||||
|
long,
|
||||||
|
group = "model_selection",
|
||||||
|
value_enum,
|
||||||
|
default_value_t = Model::Bria,
|
||||||
|
help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)"
|
||||||
|
)]
|
||||||
|
model: Model,
|
||||||
|
|
||||||
|
#[arg(long, group = "model_selection", help = "Path to custom ONNX model")]
|
||||||
model_path: Option<String>,
|
model_path: Option<String>,
|
||||||
|
|
||||||
#[arg(long, help = "Skip model download, fail if not cached")]
|
#[arg(long, help = "Skip model download, fail if not cached")]
|
||||||
|
|
@ -50,25 +60,35 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
|
||||||
// Load the image
|
// Load the image
|
||||||
let img = image::load_from_memory(&bytes)?;
|
let img = image::load_from_memory(&bytes)?;
|
||||||
let (width, height) = img.dimensions();
|
let (img_width, img_height) = img.dimensions();
|
||||||
println!("Loaded image: {}x{}\n", width, height);
|
println!("Loaded image: {}x{}\n", img_width, img_height);
|
||||||
|
|
||||||
// Get model info
|
// Get model path - either from custom path or by selecting built-in model
|
||||||
let model_info = ModelInfo::BIREFNET_LITE;
|
let model_path = if let Some(custom_path) = args.model_path {
|
||||||
|
// Custom model path provided, use it directly
|
||||||
|
let model_path = std::path::PathBuf::from(&custom_path);
|
||||||
|
if !model_path.exists() {
|
||||||
|
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
||||||
|
}
|
||||||
|
println!("Using custom model: {}", custom_path);
|
||||||
|
model_path
|
||||||
|
} else {
|
||||||
|
// Use built-in model
|
||||||
|
let model_info = args.model.info();
|
||||||
println!("Using model: {}", model_info.name);
|
println!("Using model: {}", model_info.name);
|
||||||
|
get_model_path(&model_info, None, args.offline)?
|
||||||
// Get or download model
|
};
|
||||||
let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?;
|
|
||||||
|
|
||||||
// Create ONNX Runtime session
|
// Create ONNX Runtime session
|
||||||
let mut session = create_session(&model_path)?;
|
let mut session = create_session(&model_path)?;
|
||||||
|
let input_name = session.inputs()[0].name().to_string();
|
||||||
|
|
||||||
// Preprocess image
|
// Preprocess
|
||||||
let input_tensor = preprocess_image(&img, model_info.input_size.0, model_info.input_size.1)?;
|
let input_tensor = preprocess_image(&img, 1024, 1024)?;
|
||||||
|
|
||||||
// Run inference
|
// Run inference
|
||||||
println!("Running inference...");
|
println!("Running inference...");
|
||||||
let outputs = session.run(ort::inputs!["input" => Tensor::from_array(input_tensor)?])?;
|
let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?;
|
||||||
|
|
||||||
// Extract mask output
|
// Extract mask output
|
||||||
let mask_output = &outputs[0];
|
let mask_output = &outputs[0];
|
||||||
|
|
|
||||||
37
src/model.rs
37
src/model.rs
|
|
@ -1,3 +1,5 @@
|
||||||
|
use clap::ValueEnum;
|
||||||
|
use ort::execution_providers::CUDAExecutionProvider;
|
||||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
|
@ -5,21 +7,41 @@ use std::fs;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
#[clap(rename_all = "kebab-case")]
|
||||||
|
pub enum Model {
|
||||||
|
BiRefNetLite,
|
||||||
|
Bria,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn info(&self) -> &ModelInfo {
|
||||||
|
match self {
|
||||||
|
Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE,
|
||||||
|
Model::Bria => &ModelInfo::BRIA,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Model metadata
|
/// Model metadata
|
||||||
pub struct ModelInfo {
|
pub struct ModelInfo {
|
||||||
pub name: &'static str,
|
pub name: &'static str,
|
||||||
pub url: &'static str,
|
pub url: &'static str,
|
||||||
pub sha256: Option<&'static str>,
|
pub sha256: Option<&'static str>,
|
||||||
pub input_size: (u32, u32),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModelInfo {
|
impl ModelInfo {
|
||||||
/// BiRefNet General Lite model
|
/// BiRefNet General Lite model
|
||||||
pub const BIREFNET_LITE: ModelInfo = ModelInfo {
|
pub const BIREFNET_LITE: ModelInfo = ModelInfo {
|
||||||
name: "birefnet-general-lite",
|
name: "birefnet-general-lite",
|
||||||
url: "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png",
|
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
|
||||||
sha256: None, // We'll skip verification for now
|
sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"),
|
||||||
input_size: (1024, 1024),
|
};
|
||||||
|
|
||||||
|
pub const BRIA: ModelInfo = ModelInfo {
|
||||||
|
name: "bria",
|
||||||
|
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx",
|
||||||
|
sha256: Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -156,16 +178,17 @@ pub fn get_model_path(
|
||||||
Ok(model_path)
|
Ok(model_path)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an ONNX Runtime session from model path
|
/// Create an ONNX Runtime session from model path with CUDA backend
|
||||||
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
||||||
println!("Loading model into ONNX Runtime...");
|
println!("Loading model into ONNX Runtime with CUDA backend...");
|
||||||
|
|
||||||
let session = Session::builder()?
|
let session = Session::builder()?
|
||||||
|
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||||
.with_intra_threads(4)?
|
.with_intra_threads(4)?
|
||||||
.commit_from_file(model_path)?;
|
.commit_from_file(model_path)?;
|
||||||
|
|
||||||
println!("Model loaded successfully!");
|
println!("Model loaded successfully with CUDA!");
|
||||||
|
|
||||||
Ok(session)
|
Ok(session)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue