Compare commits
No commits in common. "b979fadb434937de36a2ad184ff45118b9fc1bad" and "86e0dd734de1be6b2c4c5f65864a246206f67066" have entirely different histories.
b979fadb43
...
86e0dd734d
8 changed files with 29 additions and 83 deletions
1
.envrc
1
.envrc
|
|
@ -1 +0,0 @@
|
|||
use flake
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -1,2 +1 @@
|
|||
/target
|
||||
.direnv
|
||||
|
|
|
|||
9
.vscode/extensions.json
vendored
9
.vscode/extensions.json
vendored
|
|
@ -1,9 +0,0 @@
|
|||
{
|
||||
"recommendations": [
|
||||
"fill-labs.dependi",
|
||||
"jnoortheen.nix-ide",
|
||||
"mkhl.direnv",
|
||||
"rust-lang.rust-analyzer",
|
||||
"tamasfe.even-better-toml",
|
||||
]
|
||||
}
|
||||
12
flake.lock
12
flake.lock
|
|
@ -20,11 +20,11 @@
|
|||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1769789167,
|
||||
"narHash": "sha256-kKB3bqYJU5nzYeIROI82Ef9VtTbu4uA3YydSk/Bioa8=",
|
||||
"lastModified": 1768886240,
|
||||
"narHash": "sha256-C2TjvwYZ2VDxYWeqvvJ5XPPp6U7H66zeJlRaErJKoEM=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "62c8382960464ceb98ea593cb8321a2cf8f9e3e5",
|
||||
"rev": "80e4adbcf8992d3fd27ad4964fbb84907f9478b0",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
@ -62,11 +62,11 @@
|
|||
"nixpkgs": "nixpkgs_2"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1769915446,
|
||||
"narHash": "sha256-f1F/umtX3ZD7fF9DHSloVHc0mnAT0ry0YK2jI/6E0aI=",
|
||||
"lastModified": 1768963622,
|
||||
"narHash": "sha256-n6VHiUgrYD9yjagzG6ncVVqFbVTsKCI54tR9PNAFCo0=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "bc00300f010275e46feb3c3974df6587ff7b7808",
|
||||
"rev": "2ef5b3362af585a83bafd34e7fc9b1f388c2e5e2",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
|
|||
|
|
@ -20,9 +20,7 @@
|
|||
overlays = [ (import rust-overlay) ];
|
||||
pkgs = import nixpkgs { inherit system overlays; };
|
||||
stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv;
|
||||
rustToolchain = pkgs.rust-bin.stable.latest.default.override {
|
||||
extensions = [ "rust-src" ];
|
||||
};
|
||||
rustToolchain = pkgs.rust-bin.stable.latest.default;
|
||||
|
||||
nativeBuildInputs = with pkgs; [
|
||||
rustToolchain
|
||||
|
|
|
|||
2
rust-toolchain.toml
Normal file
2
rust-toolchain.toml
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
[toolchain]
|
||||
channel = "stable"
|
||||
46
src/main.rs
46
src/main.rs
|
|
@ -9,7 +9,7 @@ mod model;
|
|||
mod postprocessing;
|
||||
mod preprocessing;
|
||||
|
||||
use model::{Model, create_session, get_model_path};
|
||||
use model::{ModelInfo, create_session, get_model_path};
|
||||
use postprocessing::{apply_mask, create_side_by_side};
|
||||
use preprocessing::preprocess_image;
|
||||
|
||||
|
|
@ -20,17 +20,7 @@ struct Args {
|
|||
#[arg(help = "URL of the image to download and process")]
|
||||
url: String,
|
||||
|
||||
#[arg(
|
||||
short,
|
||||
long,
|
||||
group = "model_selection",
|
||||
value_enum,
|
||||
default_value_t = Model::Bria,
|
||||
help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)"
|
||||
)]
|
||||
model: Model,
|
||||
|
||||
#[arg(long, group = "model_selection", help = "Path to custom ONNX model")]
|
||||
#[arg(long, help = "Path to custom ONNX model")]
|
||||
model_path: Option<String>,
|
||||
|
||||
#[arg(long, help = "Skip model download, fail if not cached")]
|
||||
|
|
@ -60,35 +50,25 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
|
||||
// Load the image
|
||||
let img = image::load_from_memory(&bytes)?;
|
||||
let (img_width, img_height) = img.dimensions();
|
||||
println!("Loaded image: {}x{}\n", img_width, img_height);
|
||||
let (width, height) = img.dimensions();
|
||||
println!("Loaded image: {}x{}\n", width, height);
|
||||
|
||||
// Get model path - either from custom path or by selecting built-in model
|
||||
let model_path = if let Some(custom_path) = args.model_path {
|
||||
// Custom model path provided, use it directly
|
||||
let model_path = std::path::PathBuf::from(&custom_path);
|
||||
if !model_path.exists() {
|
||||
return Err(format!("Custom model path does not exist: {}", custom_path).into());
|
||||
}
|
||||
println!("Using custom model: {}", custom_path);
|
||||
model_path
|
||||
} else {
|
||||
// Use built-in model
|
||||
let model_info = args.model.info();
|
||||
println!("Using model: {}", model_info.name);
|
||||
get_model_path(&model_info, None, args.offline)?
|
||||
};
|
||||
// Get model info
|
||||
let model_info = ModelInfo::BIREFNET_LITE;
|
||||
println!("Using model: {}", model_info.name);
|
||||
|
||||
// Get or download model
|
||||
let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?;
|
||||
|
||||
// Create ONNX Runtime session
|
||||
let mut session = create_session(&model_path)?;
|
||||
let input_name = session.inputs()[0].name().to_string();
|
||||
|
||||
// Preprocess
|
||||
let input_tensor = preprocess_image(&img, 1024, 1024)?;
|
||||
// Preprocess image
|
||||
let input_tensor = preprocess_image(&img, model_info.input_size.0, model_info.input_size.1)?;
|
||||
|
||||
// Run inference
|
||||
println!("Running inference...");
|
||||
let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?;
|
||||
let outputs = session.run(ort::inputs!["input" => Tensor::from_array(input_tensor)?])?;
|
||||
|
||||
// Extract mask output
|
||||
let mask_output = &outputs[0];
|
||||
|
|
|
|||
37
src/model.rs
37
src/model.rs
|
|
@ -1,5 +1,3 @@
|
|||
use clap::ValueEnum;
|
||||
use ort::execution_providers::CUDAExecutionProvider;
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::error::Error;
|
||||
|
|
@ -7,41 +5,21 @@ use std::fs;
|
|||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
|
||||
#[clap(rename_all = "kebab-case")]
|
||||
pub enum Model {
|
||||
BiRefNetLite,
|
||||
Bria,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn info(&self) -> &ModelInfo {
|
||||
match self {
|
||||
Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE,
|
||||
Model::Bria => &ModelInfo::BRIA,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Model metadata
|
||||
pub struct ModelInfo {
|
||||
pub name: &'static str,
|
||||
pub url: &'static str,
|
||||
pub sha256: Option<&'static str>,
|
||||
pub input_size: (u32, u32),
|
||||
}
|
||||
|
||||
impl ModelInfo {
|
||||
/// BiRefNet General Lite model
|
||||
pub const BIREFNET_LITE: ModelInfo = ModelInfo {
|
||||
name: "birefnet-general-lite",
|
||||
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
|
||||
sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"),
|
||||
};
|
||||
|
||||
pub const BRIA: ModelInfo = ModelInfo {
|
||||
name: "bria",
|
||||
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx",
|
||||
sha256: Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"),
|
||||
url: "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png",
|
||||
sha256: None, // We'll skip verification for now
|
||||
input_size: (1024, 1024),
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -178,17 +156,16 @@ pub fn get_model_path(
|
|||
Ok(model_path)
|
||||
}
|
||||
|
||||
/// Create an ONNX Runtime session from model path with CUDA backend
|
||||
/// Create an ONNX Runtime session from model path
|
||||
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
|
||||
println!("Loading model into ONNX Runtime with CUDA backend...");
|
||||
println!("Loading model into ONNX Runtime...");
|
||||
|
||||
let session = Session::builder()?
|
||||
.with_execution_providers([CUDAExecutionProvider::default().build()])?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_intra_threads(4)?
|
||||
.commit_from_file(model_path)?;
|
||||
|
||||
println!("Model loaded successfully with CUDA!");
|
||||
println!("Model loaded successfully!");
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue