Input Wrapper
This is a draft, closes #878.
Note: If this design ends up being useful, it could be implemented as a separate library (there are only code additions and they don't conflict with anything), but for gathering feedback it's simpler and more straightforward to keep this as a draft PR for now.
- Add `#[input_wrapper]`.
- Add the `heck` dep to convert from CamelCase into snake_case.
- Add layers (a minimal composition sketch follows this list):
  - `Id`, which just forwards the input.
  - `On`, which applies some `Module` to an input-wrapper field. Contains a test demonstrating its usage.
  - `Add`, which calls `try_add` for the inputs.
- This is how it gets used: https://github.com/coreylowman/dfdx/blob/52b62649d38482ceea432027cabed08ceca12f52/dfdx/src/nn/layers/on.rs#L41-L64
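To give a feel for how these pieces compose, here is a minimal sketch of a residual-style block (illustrative only, not part of the PR; `ResidualConfig`/`Residual` are hypothetical names, and the `SplitInto`, `ops::Add` and `LinearConstConfig` paths follow the UNet/RNN examples further down):

```rust
// Sketch only: duplicate the input, run a Linear over one copy, then add the copies back.
#[input_wrapper]
#[derive(Clone, Debug)]
pub struct Split<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}

#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(Residual)] // hypothetical name
pub struct ResidualConfig<const DIM: usize> {
    // duplicate the input into a (forward, skip) tuple and wrap it
    pub split: SplitInto<(Id, Id)>,
    pub wrap: split::FromTuple,
    // apply a module to the `forward` field only
    pub linear: On<split::forward, LinearConstConfig<DIM, DIM>>,
    // unwrap back into a tuple and sum both halves via `try_add`
    pub unwrap: split::IntoTuple,
    pub add: ops::Add,
}
```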
This is what `input_wrapper` generates for the `Split1` wrapper in the linked usage above:
```rust
pub struct Split1<Forward, Skip> {
pub forward: Forward,
pub skip: Skip,
}
/// Automatically generated by `input_wrapper`. The containing items are visible on your project's documentation.
pub mod split1 {
use super::Split1;
/// Indicates the [`Split1::forward`] field. \nThis field is the `0` value (`0`-based index).
#[allow(non_camel_case_types)]
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct forward;
/// Indicates the [`Split1::skip`] field. \nThis field is the `1` value (`0`-based index).
#[allow(non_camel_case_types)]
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct skip;
/// Indicates a conversion from a (Forward, Skip) tuple into a `Split1<Forward, Skip>`.
#[derive(
Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule,
)]
pub struct FromTuple;
/// Indicates a conversion from a `Split1<Forward, Skip>` into a (Forward, Skip) tuple.
#[derive(
Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule,
)]
pub struct IntoTuple;
/// Conversion of a tuple into a [`Split1`].
impl<Forward, Skip> From<(Forward, Skip)> for Split1<Forward, Skip> {
fn from(x: (Forward, Skip)) -> Self {
Split1 {
forward: x.0,
skip: x.1,
}
}
}
/// Conversion of a [`Split1`] into a tuple.
impl<Forward, Skip> From<Split1<Forward, Skip>> for (Forward, Skip) {
fn from(x: Split1<Forward, Skip>) -> Self {
(x.forward, x.skip)
}
}
/// Module to convert a tuple into a [`Split1`].
impl<Forward, Skip> crate::prelude::Module<(Forward, Skip)> for FromTuple {
type Output = Split1<Forward, Skip>;
fn try_forward(&self, x: (Forward, Skip)) -> Result<Self::Output, crate::prelude::Error> {
Ok(x.into())
}
}
/// Module to convert a [`Split1`] into a tuple.
impl<Forward, Skip> crate::prelude::Module<Split1<Forward, Skip>> for IntoTuple {
type Output = (Forward, Skip);
fn try_forward(
&self,
x: Split1<Forward, Skip>,
) -> Result<Self::Output, crate::prelude::Error> {
Ok(x.into())
}
}
/// Module that access [`Split1::forward`] and then applies Module `M` on it.
impl<M: crate::prelude::Module<Forward>, Forward, Skip>
crate::prelude::Module<Split1<Forward, Skip>> for crate::prelude::On<forward, M>
{
type Output = Split1<<M as crate::prelude::Module<Forward>>::Output, Skip>;
fn try_forward(
&self,
x: Split1<Forward, Skip>,
) -> Result<Self::Output, crate::prelude::Error> {
let x0 = x.forward;
let x1 = x.skip;
let x0 = self.t.try_forward(x0)?;
let x = Split1 {
forward: x0,
skip: x1,
};
Ok(x)
}
fn try_forward_mut(
&mut self,
x: Split1<Forward, Skip>,
) -> Result<Self::Output, crate::prelude::Error> {
let x0 = x.forward;
let x1 = x.skip;
let x0 = self.t.try_forward_mut(x0)?;
let x = Split1 {
forward: x0,
skip: x1,
};
Ok(x)
}
}
/// Module that access [`Split1::skip`] and then applies Module `M` on it.
impl<M: crate::prelude::Module<Skip>, Forward, Skip>
crate::prelude::Module<Split1<Forward, Skip>> for crate::prelude::On<skip, M>
{
type Output = Split1<Forward, <M as crate::prelude::Module<Skip>>::Output>;
fn try_forward(
&self,
x: Split1<Forward, Skip>,
) -> Result<Self::Output, crate::prelude::Error> {
let x0 = x.forward;
let x1 = x.skip;
let x1 = self.t.try_forward(x1)?;
let x = Split1 {
forward: x0,
skip: x1,
};
Ok(x)
}
fn try_forward_mut(
&mut self,
x: Split1<Forward, Skip>,
) -> Result<Self::Output, crate::prelude::Error> {
let x0 = x.forward;
let x1 = x.skip;
let x1 = self.t.try_forward_mut(x1)?;
let x = Split1 {
forward: x0,
skip: x1,
};
Ok(x)
}
}
}
```
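The generated conversions can also be used directly; a small illustrative fragment (not part of the PR's test, assuming a surrounding function that returns `Result<_, dfdx::prelude::Error>` and that `Module` is in scope):

```rust
// Illustrative only: exercise the generated `From` impls and the conversion Modules.
let wrapped: Split1<f32, f32> = (1.0, 2.0).into(); // tuple -> wrapper
let (forward, skip): (f32, f32) = wrapped.into();  // wrapper -> tuple

// The same conversions are exposed as Modules, so they can also be called via `try_forward`:
let wrapped = split1::FromTuple.try_forward((forward, skip))?;
let _tuple: (f32, f32) = split1::IntoTuple.try_forward(wrapped)?;
```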
For additional context, this is how I was able to define a UNet:
(Note: in this case I was using a version of dfdx that had some other local changes, especially experimental ones.)
```rust
#[input_wrapper]
#[derive(Clone, Debug)]
pub struct Split<Forward, Skip> {
pub forward: Forward,
pub skip: Skip,
}
impl<Forward, Skip, const AXIS: isize> TryConcatTensorAlong<Axis<AXIS>> for Split<Forward, Skip>
where
(Forward, Skip): TryConcatTensorAlong<Axis<AXIS>>,
{
type Output = <(Forward, Skip) as TryConcatTensorAlong<Axis<AXIS>>>::Output;
fn try_concat_tensor_along(self, ax: Axis<AXIS>) -> Result<Self::Output, Error> {
let (forward, skip) = self.into();
(forward, skip).try_concat_tensor_along(ax)
}
}
/// From:
/// ```ignore
/// batch * CH_IN * height * width
/// ```
///
/// To:
/// ```ignore
/// batch * CH_OUT * height * width
/// ```
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(ConvBlock)]
pub struct ConvBlockConfig<const CH_IN: usize, const CH_OUT: usize> {
pub conv_1: Conv2DConstConfig<CH_IN, CH_OUT, 3, 1, 1>,
pub norm_1: BatchNorm2DConstConfig<CH_OUT>,
pub a_1: ReLU,
//
pub conv_2: Conv2DConstConfig<CH_OUT, CH_OUT, 3, 1, 1>,
pub norm_2: BatchNorm2DConstConfig<CH_OUT>,
pub a_2: ReLU,
}
/// From:
/// ```ignore
/// batch * CH_IN * height * width
/// ```
///
/// To:
/// ```ignore
/// Split {
/// forward: batch * CH_OUT * height/2 * width/2,
/// skip: batch * CH_OUT * height * width,
/// }
/// ```
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(DownBlock)]
pub struct DownBlockConfig<const CH_IN: usize, const CH_OUT: usize> {
pub conv: ConvBlockConfig<CH_IN, CH_OUT>,
//
pub split: SplitInto<(Id, Id)>,
pub wrapper: split::FromTuple,
pub pool: On<split::forward, MaxPool2DConst<2, 2, 0>>,
}
/// From:
/// ```ignore
/// Split {
/// forward: batch * CH_INF * height/2 * width/2,
/// skip: batch * CH_INS * height * width,
/// }
/// ```
///
/// To:
/// ```ignore
/// batch * CH_OUT * height * width
/// ```
///
/// Notes:
/// - `CH_INF` refers to the #channels from [`Split::forward`].
/// - `CH_INS` refers to the #channels from [`Split::skip`], but this parameter is not directly passed to this structure.
/// - `CH_CONCAT` is supposed to be `CH_OUT + CH_INS`.
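/// - For example, `up_block_4` in `ModelConfig` below has `CH_INF = 512` and `CH_OUT = 256`, and its skip comes from `down_block_3` with `CH_INS = 256`, so `CH_CONCAT = 256 + 256 = 512`.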
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(UpBlock)]
pub struct UpBlockConfig<const CH_INF: usize, const CH_OUT: usize, const CH_CONCAT: usize> {
    // for keras' padding='same', the PADDING value must be set to:
    // ((kernel-1) * dilation + 1) // 2 = ((3-1) * 1 + 1) // 2 = 1
pub conv_trans:
On<split::forward, ConvTrans2DConstConfig<CH_INF, CH_OUT, 3, 2, 1, 1, 1, 1>>,
pub bias: On<split::forward, Bias2DConstConfig<CH_OUT>>,
// concat "skip" and "forward" along channels
pub tuple: split::IntoTuple,
pub concat: ops::ConcatTensorAlong<Axis<1>>,
pub conv: ConvBlockConfig<CH_CONCAT, CH_OUT>,
}
/// Just applies `M`.
type Onc0<M> = M;
/// Access `F` and then applies `M`.
type Onc1<F, M> = On<F, M>;
/// Access `F` consecutively 2 times and then applies `M`.
type Onc2<F, M> = On<F, On<F, M>>;
/// Access `F` consecutively 3 times and then applies `M`.
type Onc3<F, M> = On<F, Onc2<F, M>>;
/// Access `F` consecutively 4 times and then applies `M`.
type Onc4<F, M> = On<F, Onc3<F, M>>;
/// Access `F` consecutively 5 times and then applies `M`.
type Onc5<F, M> = On<F, Onc4<F, M>>;
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(Model)]
pub struct ModelConfig {
// encoder
pub down_block_0: Onc0<DownBlockConfig<3, 32>>,
pub down_block_1: Onc1<split::forward, DownBlockConfig<32, 64>>,
pub down_block_2: Onc2<split::forward, DownBlockConfig<64, 128>>,
pub down_block_3: Onc3<split::forward, DownBlockConfig<128, 256>>,
// bottleneck
// note: this increases channels but does not reduces the height nor the width
pub conv_bottle: Onc4<split::forward, ConvBlockConfig<256, 512>>,
// decoder
pub up_block_4: Onc3<split::forward, UpBlockConfig<512, 256, 512>>,
pub up_block_3: Onc2<split::forward, UpBlockConfig<256, 128, 256>>,
pub up_block_2: Onc1<split::forward, UpBlockConfig<128, 64, 128>>,
pub up_block_1: Onc0<UpBlockConfig<64, 32, 64>>,
// yclass channel conversion
pub conv_2: Conv2DConstConfig<32, 13, 1, 1, 0>,
pub bias_2: Bias2DConstConfig<13>,
}
```
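For completeness, building and running the model follows the usual dfdx pattern; here is a sketch (the batch/height/width values are placeholders, and the device setup mirrors the RNN test below):

```rust
// Sketch: build the UNet and run an inference forward pass.
let dev = Cuda::try_build(0, 0)?;
let model = dev.build_module::<f32>(ModelConfig::default());

const BATCH: usize = 2;
const HEIGHT: usize = 64; // must be divisible by 2^4 because of the four down blocks
const WIDTH: usize = 64;

// batch * 3 * height * width input image
let x: Tensor<Rank4<BATCH, 3, HEIGHT, WIDTH>, f32, _> = dev.sample_uniform();
// batch * 13 * height * width output logits, per the final `conv_2`/`bias_2` layers
let _y = model.forward(x);
```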
And this is how I defined and trained a simple RNN (based on this exercise). The trained model was able to generate dino-like names.
```rust
pub mod model {
use super::*;
#[input_wrapper]
#[derive(Clone, Debug)]
pub struct Input<A, X> {
pub a_prev: A,
pub x: X,
}
impl<A, X, const AXIS: isize> TryConcatTensorAlong<Axis<AXIS>> for Input<A, X>
where
(A, X): TryConcatTensorAlong<Axis<AXIS>>,
{
type Output = <(A, X) as TryConcatTensorAlong<Axis<AXIS>>>::Output;
fn try_concat_tensor_along(self, ax: Axis<AXIS>) -> Result<Self::Output, Error> {
(self.a_prev, self.x).try_concat_tensor_along(ax)
}
}
#[input_wrapper]
#[derive(Clone, Debug)]
pub struct Output<A, Y> {
pub a: A,
pub y: Y,
}
impl<AS: Shape, YS: Shape, E: Dtype, D: Device<E>>
Output<Tensor<AS, E, D, OwnedTape<E, D>>, Tensor<YS, E, D, OwnedTape<E, D>>>
{
pub fn merge_tapes_on_y(self) -> Self {
let (a, at) = self.a.split_tape();
let (y, yt) = self.y.split_tape();
Self {
a: a.leaky_traced(),
y: y.put_tape(at.merge(yt)),
}
}
}
//
/// Input:
/// ```ignore
/// Input {
/// a_prev: A,
/// x: X,
/// }
/// ```
///
/// Output:
/// ```ignore
/// Output {
/// a: A,
/// y: Y,
/// }
/// ```
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(Cell)]
pub struct CellConstConfig<
const NA: usize,
const NX: usize,
const NY: usize,
const CONCAT_AXIS: isize,
const NAPNX: usize = { NA + NX },
> {
// doing concat(a_prev, x) dot concat(wa^t, wb^t) + b is the same as
// doing a_prev dot wa^t + x dot wb^t + b
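        // (block-matrix view: with W = [Wa | Wb], concat(a_prev, x) · W^T + b
        //  expands to a_prev · Wa^T + x · Wb^T + b, so the single Linear on the
        //  concatenated input used here matches the commented-out variant below)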
//
// pub amul: On<input::a_prev, MatMulConstConfig<NA, NA>>,
// pub xmul: On<input::x, MatMulConstConfig<NX, NA>>,
// pub ax_tuple: input::IntoTuple,
// pub ax_add: ops::Add,
// pub bias: Bias1DConstConfig<NA>,
//
pub concat_input: ops::ConcatTensorAlong<Axis<CONCAT_AXIS>>,
pub ax_linear: LinearConstConfig<NAPNX, NA>,
//
pub g1: Tanh,
pub ay_tuple: SplitInto<(Id, Id)>,
pub ay: output::FromTuple,
pub ylinear: On<output::y, LinearConstConfig<NA, NY>>,
}
}
#[test]
fn test_rnn() -> anyhow::Result<()> {
    let dev = Cuda::try_build(0, 0)?;
    // local alias for the device, used by the tensor type aliases below
    type Device_ = Cuda;
const T_: usize = 2;
const BATCH: usize = 3;
// for unbatched (1D) tensors, the concat axis is 0
// for batched (2D) tensors, the concat axis is 1
const CONCAT_AXIS: isize = 1;
const NA: usize = 2;
const NX: usize = 3;
const NY: usize = 3;
type XT<T = NoneTape> = Tensor<Rank2<BATCH, NX>, f32, Device_, T>;
type AT<T = NoneTape> = Tensor<Rank2<BATCH, NA>, f32, Device_, T>;
type YT<T = NoneTape> = Tensor<Rank2<BATCH, NY>, f32, Device_, T>;
let mut model =
dev.build_module::<f32>(model::CellConstConfig::<NA, NX, NY, CONCAT_AXIS>::default());
let mut grads = model.alloc_grads();
let mut opt = dfdx::prelude::optim::Adam::new(
&model,
AdamConfig {
lr: 1e-4,
..Default::default() // weight_decay: Some(dfdx::nn::optim::WeightDecay::L2(0.001)),
},
);
const EPOCHS: usize = 2;
for e in 0..EPOCHS {
let a_prev: AT = dev.zeros();
let mut x: XT = dev.zeros();
let mut a_prev_t: AT<_> = a_prev.leaky_traced();
let mut batch_loss = 0f32;
for _t in 0..T_ {
let y_t: YT = dev.sample_uniform();
let x_t: XT<OwnedTape<f32, Device_>> = x.traced(grads);
let input = model::Input {
a_prev: a_prev_t,
x: x_t,
};
let prediction = model.forward_mut(input);
let prediction = prediction.merge_tapes_on_y();
let loss_t =
dfdx::losses::cross_entropy_with_logits_loss(prediction.y, y_t.clone());
batch_loss += loss_t.array();
// Note:
// Running backprop and model update for each timestep t.
// A different approach would be to run backprop at the last timestep and update once.
// Or yet do something in between.
grads = loss_t.backward();
opt.update(&mut model, &grads).unwrap();
x = y_t;
a_prev_t = prediction.a;
}
println!("epoch: {}; loss: {}", e, batch_loss);
// grads.drop_non_leafs();
model.zero_grads(&mut grads);
}
Ok(())
}
```
I'll prioritize moving this experiment to a separate crate, but feel free to ping me in case anyone has questions or suggestions.