rai
RAI: Rust ML framework with composable transformations like JAX.
RAI
A machine-learning framework with ergonomic APIs in Rust, featuring lazy computation and JAX-like composable transformations.
Installation
cargo add rai
Code snippets
Function transformations (jvp, vjp, grad, value_and_grad)
use rai::{grad, Cpu, Tensor, F32};
fn f(x: &Tensor) -> Tensor {
x.sin()
}
/// Applies `grad` twice to `f` (a second-order derivative) and evaluates the result at a
/// one-element `F32` tensor of ones on the CPU.
///
/// Prints the computation graph (via `dot_graph`) first, then the resulting tensor.
fn main() {
    let second_derivative = grad(grad(f));
    // Own the input tensor and lend it to the transformed function.
    let input = Tensor::ones([1], F32, &Cpu);
    let d2f = second_derivative(&input);
    println!("{}", d2f.dot_graph());
    println!("{}", d2f);
}
NN modules, optimizers, and loss functions
/// Computes the mean softmax cross-entropy loss of `model` on `input` against `labels`.
///
/// Returns `(loss, Aux(logits))`. The raw logits are wrapped in `Aux` so they travel
/// alongside the loss through `value_and_grad` as an auxiliary output.
fn loss_fn<M>(model: &M, input: &Tensor, labels: &Tensor) -> (Tensor, Aux<Tensor>)
where
    M: TrainableModule<Input = Tensor, Output = Tensor>,
{
    let logits = model.forward(input);
    let per_example = softmax_cross_entropy(&logits, labels);
    // Average over every axis (`..`) to get the scalar training loss.
    (per_example.mean(..), Aux(logits))
}
/// Performs one optimization step for `model` on a single `(input, labels)` batch.
///
/// Differentiates `loss_fn` with `value_and_grad` and hands the resulting gradients to
/// `optimizer.step`. The returned parameters are evaluated and then written back into
/// `model` via `update_params`.
fn train_step<M: TrainableModule<Input = Tensor, Output = Tensor>, O: Optimizer>(
    optimizer: &mut O,
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) {
    let vg_fn = value_and_grad(loss_fn);
    // Only the first gradient entry is used; the loss, the logits, and the remaining
    // gradient entries (presumably those for `input`/`labels` — confirm) are ignored.
    let ((_loss, Aux(_logits)), (grads, ..)) = vg_fn((model, input, labels));
    let mut params = optimizer.step(&grads);
    // Borrow `params` here so it can still be passed mutably to `update_params` below.
    // Given the framework's lazy computation, this presumably materializes the updated
    // parameters before they are stored — confirm against `eval`'s docs.
    eval(&params);
    model.update_params(&mut params);
}
Examples
- linear_regression
cargo run --bin linear_regression --release
- mnist
cargo run --bin mnist --release
cargo run --bin mnist --release --features=cuda
- mnist-cnn
cargo run --bin mnist-cnn --release
cargo run --bin mnist-cnn --release --features=cuda
- phi2
cargo run --bin phi2 --release
cargo run --bin phi2 --release --features=cuda
- phi3
cargo run --bin phi3 --release
cargo run --bin phi3 --release --features=cuda
- qwen2
cargo run --bin qwen2 --release
cargo run --bin qwen2 --release --features=cuda
- gemma
- accept the license agreement at https://huggingface.co/google/gemma-2b
- install the Hugging Face Hub client:
pip install huggingface_hub
- log in to Hugging Face:
huggingface-cli login
cargo run --bin gemma --release
cargo run --bin gemma --release --features=cuda
- vit
cargo run --bin vit --release
cargo run --bin vit --release --features=cuda
LICENSE
This project is licensed under either of
- Apache License, Version 2.0, (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
- MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)
at your option.