Compare commits

..

No commits in common. "b979fadb434937de36a2ad184ff45118b9fc1bad" and "86e0dd734de1be6b2c4c5f65864a246206f67066" have entirely different histories.

8 changed files with 29 additions and 83 deletions

1
.envrc
View file

@ -1 +0,0 @@
use flake

1
.gitignore vendored
View file

@ -1,2 +1 @@
/target /target
.direnv

View file

@ -1,9 +0,0 @@
{
"recommendations": [
"fill-labs.dependi",
"jnoortheen.nix-ide",
"mkhl.direnv",
"rust-lang.rust-analyzer",
"tamasfe.even-better-toml",
]
}

View file

@ -20,11 +20,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1769789167, "lastModified": 1768886240,
"narHash": "sha256-kKB3bqYJU5nzYeIROI82Ef9VtTbu4uA3YydSk/Bioa8=", "narHash": "sha256-C2TjvwYZ2VDxYWeqvvJ5XPPp6U7H66zeJlRaErJKoEM=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "62c8382960464ceb98ea593cb8321a2cf8f9e3e5", "rev": "80e4adbcf8992d3fd27ad4964fbb84907f9478b0",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -62,11 +62,11 @@
"nixpkgs": "nixpkgs_2" "nixpkgs": "nixpkgs_2"
}, },
"locked": { "locked": {
"lastModified": 1769915446, "lastModified": 1768963622,
"narHash": "sha256-f1F/umtX3ZD7fF9DHSloVHc0mnAT0ry0YK2jI/6E0aI=", "narHash": "sha256-n6VHiUgrYD9yjagzG6ncVVqFbVTsKCI54tR9PNAFCo0=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "bc00300f010275e46feb3c3974df6587ff7b7808", "rev": "2ef5b3362af585a83bafd34e7fc9b1f388c2e5e2",
"type": "github" "type": "github"
}, },
"original": { "original": {

View file

@ -20,9 +20,7 @@
overlays = [ (import rust-overlay) ]; overlays = [ (import rust-overlay) ];
pkgs = import nixpkgs { inherit system overlays; }; pkgs = import nixpkgs { inherit system overlays; };
stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv; stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv;
rustToolchain = pkgs.rust-bin.stable.latest.default.override { rustToolchain = pkgs.rust-bin.stable.latest.default;
extensions = [ "rust-src" ];
};
nativeBuildInputs = with pkgs; [ nativeBuildInputs = with pkgs; [
rustToolchain rustToolchain

2
rust-toolchain.toml Normal file
View file

@ -0,0 +1,2 @@
[toolchain]
channel = "stable"

View file

@ -9,7 +9,7 @@ mod model;
mod postprocessing; mod postprocessing;
mod preprocessing; mod preprocessing;
use model::{Model, create_session, get_model_path}; use model::{ModelInfo, create_session, get_model_path};
use postprocessing::{apply_mask, create_side_by_side}; use postprocessing::{apply_mask, create_side_by_side};
use preprocessing::preprocess_image; use preprocessing::preprocess_image;
@ -20,17 +20,7 @@ struct Args {
#[arg(help = "URL of the image to download and process")] #[arg(help = "URL of the image to download and process")]
url: String, url: String,
#[arg( #[arg(long, help = "Path to custom ONNX model")]
short,
long,
group = "model_selection",
value_enum,
default_value_t = Model::Bria,
help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)"
)]
model: Model,
#[arg(long, group = "model_selection", help = "Path to custom ONNX model")]
model_path: Option<String>, model_path: Option<String>,
#[arg(long, help = "Skip model download, fail if not cached")] #[arg(long, help = "Skip model download, fail if not cached")]
@ -60,35 +50,25 @@ fn main() -> Result<(), Box<dyn Error>> {
// Load the image // Load the image
let img = image::load_from_memory(&bytes)?; let img = image::load_from_memory(&bytes)?;
let (img_width, img_height) = img.dimensions(); let (width, height) = img.dimensions();
println!("Loaded image: {}x{}\n", img_width, img_height); println!("Loaded image: {}x{}\n", width, height);
// Get model path - either from custom path or by selecting built-in model // Get model info
let model_path = if let Some(custom_path) = args.model_path { let model_info = ModelInfo::BIREFNET_LITE;
// Custom model path provided, use it directly println!("Using model: {}", model_info.name);
let model_path = std::path::PathBuf::from(&custom_path);
if !model_path.exists() { // Get or download model
return Err(format!("Custom model path does not exist: {}", custom_path).into()); let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?;
}
println!("Using custom model: {}", custom_path);
model_path
} else {
// Use built-in model
let model_info = args.model.info();
println!("Using model: {}", model_info.name);
get_model_path(&model_info, None, args.offline)?
};
// Create ONNX Runtime session // Create ONNX Runtime session
let mut session = create_session(&model_path)?; let mut session = create_session(&model_path)?;
let input_name = session.inputs()[0].name().to_string();
// Preprocess // Preprocess image
let input_tensor = preprocess_image(&img, 1024, 1024)?; let input_tensor = preprocess_image(&img, model_info.input_size.0, model_info.input_size.1)?;
// Run inference // Run inference
println!("Running inference..."); println!("Running inference...");
let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?; let outputs = session.run(ort::inputs!["input" => Tensor::from_array(input_tensor)?])?;
// Extract mask output // Extract mask output
let mask_output = &outputs[0]; let mask_output = &outputs[0];

View file

@ -1,5 +1,3 @@
use clap::ValueEnum;
use ort::execution_providers::CUDAExecutionProvider;
use ort::session::{Session, builder::GraphOptimizationLevel}; use ort::session::{Session, builder::GraphOptimizationLevel};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use std::error::Error; use std::error::Error;
@ -7,41 +5,21 @@ use std::fs;
use std::io::Write; use std::io::Write;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
#[clap(rename_all = "kebab-case")]
pub enum Model {
BiRefNetLite,
Bria,
}
impl Model {
pub fn info(&self) -> &ModelInfo {
match self {
Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE,
Model::Bria => &ModelInfo::BRIA,
}
}
}
/// Model metadata /// Model metadata
pub struct ModelInfo { pub struct ModelInfo {
pub name: &'static str, pub name: &'static str,
pub url: &'static str, pub url: &'static str,
pub sha256: Option<&'static str>, pub sha256: Option<&'static str>,
pub input_size: (u32, u32),
} }
impl ModelInfo { impl ModelInfo {
/// BiRefNet General Lite model /// BiRefNet General Lite model
pub const BIREFNET_LITE: ModelInfo = ModelInfo { pub const BIREFNET_LITE: ModelInfo = ModelInfo {
name: "birefnet-general-lite", name: "birefnet-general-lite",
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx", url: "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png",
sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"), sha256: None, // We'll skip verification for now
}; input_size: (1024, 1024),
pub const BRIA: ModelInfo = ModelInfo {
name: "bria",
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx",
sha256: Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"),
}; };
} }
@ -178,17 +156,16 @@ pub fn get_model_path(
Ok(model_path) Ok(model_path)
} }
/// Create an ONNX Runtime session from model path with CUDA backend /// Create an ONNX Runtime session from model path
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> { pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
println!("Loading model into ONNX Runtime with CUDA backend..."); println!("Loading model into ONNX Runtime...");
let session = Session::builder()? let session = Session::builder()?
.with_execution_providers([CUDAExecutionProvider::default().build()])?
.with_optimization_level(GraphOptimizationLevel::Level3)? .with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)? .with_intra_threads(4)?
.commit_from_file(model_path)?; .commit_from_file(model_path)?;
println!("Model loaded successfully with CUDA!"); println!("Model loaded successfully!");
Ok(session) Ok(session)
} }