Compare commits

...

10 commits

Author SHA1 Message Date
Matthew Deville
b979fadb43 wip 2026-02-01 18:42:22 +01:00
Matthew Deville
1a338de3c0 better tooling setup 2026-02-01 17:26:16 +01:00
Matthew Deville
596cd28dcc typed models 2026-01-23 13:26:42 +01:00
Matthew Deville
6e965c4a18 wip 2026-01-22 23:39:46 +01:00
Matthew Deville
609afde727 wip 2026-01-22 23:36:35 +01:00
Matthew Deville
22c9315b96 wip 2026-01-22 23:20:38 +01:00
Matthew Deville
9e9721cbbc infer input size 2026-01-22 21:48:08 +01:00
Matthew Deville
f6961c3177 wip 2026-01-22 01:37:09 +01:00
Matthew Deville
1fae113c5c update 2026-01-22 01:28:12 +01:00
Matthew Deville
75dd44f27a direnv 2026-01-22 01:09:39 +01:00
8 changed files with 83 additions and 29 deletions

1
.envrc Normal file
View file

@@ -0,0 +1 @@
use flake

1
.gitignore vendored
View file

@@ -1 +1,2 @@
/target
.direnv

9
.vscode/extensions.json vendored Normal file
View file

@@ -0,0 +1,9 @@
{
"recommendations": [
"fill-labs.dependi",
"jnoortheen.nix-ide",
"mkhl.direnv",
"rust-lang.rust-analyzer",
"tamasfe.even-better-toml",
]
}

View file

@@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1768886240,
"narHash": "sha256-C2TjvwYZ2VDxYWeqvvJ5XPPp6U7H66zeJlRaErJKoEM=",
"lastModified": 1769789167,
"narHash": "sha256-kKB3bqYJU5nzYeIROI82Ef9VtTbu4uA3YydSk/Bioa8=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "80e4adbcf8992d3fd27ad4964fbb84907f9478b0",
"rev": "62c8382960464ceb98ea593cb8321a2cf8f9e3e5",
"type": "github"
},
"original": {
@@ -62,11 +62,11 @@
"nixpkgs": "nixpkgs_2"
},
"locked": {
"lastModified": 1768963622,
"narHash": "sha256-n6VHiUgrYD9yjagzG6ncVVqFbVTsKCI54tR9PNAFCo0=",
"lastModified": 1769915446,
"narHash": "sha256-f1F/umtX3ZD7fF9DHSloVHc0mnAT0ry0YK2jI/6E0aI=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "2ef5b3362af585a83bafd34e7fc9b1f388c2e5e2",
"rev": "bc00300f010275e46feb3c3974df6587ff7b7808",
"type": "github"
},
"original": {

View file

@@ -20,7 +20,9 @@
overlays = [ (import rust-overlay) ];
pkgs = import nixpkgs { inherit system overlays; };
stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.clangStdenv;
rustToolchain = pkgs.rust-bin.stable.latest.default;
rustToolchain = pkgs.rust-bin.stable.latest.default.override {
extensions = [ "rust-src" ];
};
nativeBuildInputs = with pkgs; [
rustToolchain

View file

@@ -1,2 +0,0 @@
[toolchain]
channel = "stable"

View file

@@ -9,7 +9,7 @@ mod model;
mod postprocessing;
mod preprocessing;
use model::{ModelInfo, create_session, get_model_path};
use model::{Model, create_session, get_model_path};
use postprocessing::{apply_mask, create_side_by_side};
use preprocessing::preprocess_image;
@@ -20,7 +20,17 @@ struct Args {
#[arg(help = "URL of the image to download and process")]
url: String,
#[arg(long, help = "Path to custom ONNX model")]
#[arg(
short,
long,
group = "model_selection",
value_enum,
default_value_t = Model::Bria,
help = "Model to use: 'bria' or 'birefnet-lite' (mutually exclusive with --model-path)"
)]
model: Model,
#[arg(long, group = "model_selection", help = "Path to custom ONNX model")]
model_path: Option<String>,
#[arg(long, help = "Skip model download, fail if not cached")]
@@ -50,25 +60,35 @@ fn main() -> Result<(), Box<dyn Error>> {
// Load the image
let img = image::load_from_memory(&bytes)?;
let (width, height) = img.dimensions();
println!("Loaded image: {}x{}\n", width, height);
let (img_width, img_height) = img.dimensions();
println!("Loaded image: {}x{}\n", img_width, img_height);
// Get model info
let model_info = ModelInfo::BIREFNET_LITE;
// Get model path - either from custom path or by selecting built-in model
let model_path = if let Some(custom_path) = args.model_path {
// Custom model path provided, use it directly
let model_path = std::path::PathBuf::from(&custom_path);
if !model_path.exists() {
return Err(format!("Custom model path does not exist: {}", custom_path).into());
}
println!("Using custom model: {}", custom_path);
model_path
} else {
// Use built-in model
let model_info = args.model.info();
println!("Using model: {}", model_info.name);
// Get or download model
let model_path = get_model_path(&model_info, args.model_path.as_deref(), args.offline)?;
get_model_path(&model_info, None, args.offline)?
};
// Create ONNX Runtime session
let mut session = create_session(&model_path)?;
let input_name = session.inputs()[0].name().to_string();
// Preprocess image
let input_tensor = preprocess_image(&img, model_info.input_size.0, model_info.input_size.1)?;
// Preprocess
let input_tensor = preprocess_image(&img, 1024, 1024)?;
// Run inference
println!("Running inference...");
let outputs = session.run(ort::inputs!["input" => Tensor::from_array(input_tensor)?])?;
let outputs = session.run(ort::inputs![input_name => Tensor::from_array(input_tensor)?])?;
// Extract mask output
let mask_output = &outputs[0];

View file

@@ -1,3 +1,5 @@
use clap::ValueEnum;
use ort::execution_providers::CUDAExecutionProvider;
use ort::session::{Session, builder::GraphOptimizationLevel};
use sha2::{Digest, Sha256};
use std::error::Error;
@@ -5,21 +7,41 @@ use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
#[clap(rename_all = "kebab-case")]
pub enum Model {
BiRefNetLite,
Bria,
}
impl Model {
pub fn info(&self) -> &ModelInfo {
match self {
Model::BiRefNetLite => &ModelInfo::BIREFNET_LITE,
Model::Bria => &ModelInfo::BRIA,
}
}
}
/// Model metadata
pub struct ModelInfo {
pub name: &'static str,
pub url: &'static str,
pub sha256: Option<&'static str>,
pub input_size: (u32, u32),
}
impl ModelInfo {
/// BiRefNet General Lite model
pub const BIREFNET_LITE: ModelInfo = ModelInfo {
name: "birefnet-general-lite",
url: "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png",
sha256: None, // We'll skip verification for now
input_size: (1024, 1024),
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
sha256: Some("5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333"),
};
pub const BRIA: ModelInfo = ModelInfo {
name: "bria",
url: "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx",
sha256: Some("5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"),
};
}
@@ -156,16 +178,17 @@ pub fn get_model_path(
Ok(model_path)
}
/// Create an ONNX Runtime session from model path
/// Create an ONNX Runtime session from model path with CUDA backend
pub fn create_session(model_path: &Path) -> Result<Session, Box<dyn Error>> {
println!("Loading model into ONNX Runtime...");
println!("Loading model into ONNX Runtime with CUDA backend...");
let session = Session::builder()?
.with_execution_providers([CUDAExecutionProvider::default().build()])?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)?
.commit_from_file(model_path)?;
println!("Model loaded successfully!");
println!("Model loaded successfully with CUDA!");
Ok(session)
}