
Add transpose operation

Open coreylowman opened this issue 1 year ago • 10 comments

This would be used for matmuls & transformer implementations. ~~The hard thing about this is how to do this without moving data around a lot.~~ I believe other frameworks just change stride & dimensions without actually moving around data, which is what we can/should do as well. ~~definitely a weakness of storing actual arrays.~~

Originally posted by @jafioti in https://github.com/coreylowman/dfdx/issues/34#issuecomment-1190772585

coreylowman avatar Jul 20 '22 22:07 coreylowman
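The zero-copy idea in the original post can be sketched in a few lines: a transposed "view" is the same flat buffer with dims and strides swapped, so no data moves. This is a hypothetical sketch, not dfdx code; View and its methods are illustrative names.

```rust
// Hypothetical sketch of stride-based transpose: the buffer never moves,
// only the dims and strides are swapped. Not dfdx API.
struct View {
    data: Vec<f32>,      // flat row-major storage
    dims: [usize; 2],
    strides: [usize; 2], // elements to skip per step along each axis
}

impl View {
    fn new(dims: [usize; 2], data: Vec<f32>) -> Self {
        assert_eq!(data.len(), dims[0] * dims[1]);
        View { data, dims, strides: [dims[1], 1] }
    }

    fn get(&self, i: usize, j: usize) -> f32 {
        self.data[i * self.strides[0] + j * self.strides[1]]
    }

    // Transpose without touching the buffer: swap dims and strides.
    fn transpose(mut self) -> Self {
        self.dims.swap(0, 1);
        self.strides.swap(0, 1);
        self
    }
}
```

The cost of a transpose here is two swaps of two usizes, regardless of tensor size; the trade-off is that downstream kernels must honor arbitrary strides instead of assuming contiguous layout.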

@coreylowman Do you think that at some point it will make more sense to use flat arrays for tensor storage, with strides and shapes, or to keep the actual nested arrays?

jafioti avatar Jul 21 '22 03:07 jafioti

@jafioti I was thinking about this more overnight - I think we could actually do both (so I take back my comment in the main post). We could keep the array storage and then just convert to slices for computations, or store them as slices and have the frontend interface expose them as Rust arrays.

coreylowman avatar Jul 21 '22 12:07 coreylowman

+1 for storing strides and dimensions alongside a flat array. This should lead to a clean API for reshaping (views) and reducing over arbitrary dimensions.

vikigenius avatar Jul 21 '22 22:07 vikigenius

I'd love to just go flat and call it a day right now, but one issue is that we'd be forced to go full nightly. The size of the flat array in a Tensor3D<A, B, C> would be [f32; {A * B * C}], which requires nightly const generic expressions.

jafioti avatar Jul 21 '22 22:07 jafioti

We should be able to do boxed slices without nightly though; all tensors would hold an Rc<[f32]>. The number of elements isn't captured in the type, so I'm not sure if Rust would still be able to auto-vectorize things, but we won't know unless we try!

The biggest changes for this will be in the Device traits, which are all implemented using recursive array traits.

coreylowman avatar Jul 21 '22 22:07 coreylowman
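The Rc<[f32]> idea can be sketched on stable Rust: the element count lives in the slice length rather than the type, so no [f32; {A * B * C}] const expression is needed, and clones are cheap refcount bumps. Storage3d and its methods are hypothetical names, not dfdx API.

```rust
use std::rc::Rc;

// Hypothetical sketch of non-nightly flat storage: the length lives in the
// value (the slice), not the type. Names are illustrative, not dfdx API.
struct Storage3d {
    data: Rc<[f32]>,  // shared flat buffer
    shape: [usize; 3],
}

impl Storage3d {
    fn zeros(shape: [usize; 3]) -> Self {
        let n = shape[0] * shape[1] * shape[2];
        // Vec<f32> converts into Rc<[f32]> via From<Vec<T>>.
        Storage3d { data: vec![0.0; n].into(), shape }
    }

    fn len(&self) -> usize {
        self.data.len()
    }

    // Cloning is cheap: it bumps the Rc refcount instead of copying data.
    fn share(&self) -> Self {
        Storage3d { data: Rc::clone(&self.data), shape: self.shape }
    }
}
```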

@coreylowman Have you looked into how to do this? I'm going to need transpose for my MultiHeadAttention implementation, so I could either try to do this (though I would rather keep it as a separate PR) or just implement an inefficient copying transpose temporarily, and once this lands we can switch transpose to a more efficient impl.

jafioti avatar Jul 23 '22 04:07 jafioti

@jafioti Yeah, I tried out slices and also a separate transpose operation. Here are my notes:

  1. slices - this would be a huge change, and I'm also not exactly sure how it interacts with const generics. I think we'd need some separate data storage struct for tensors that doesn't carry the const generics. For example, if you transpose a Tensor2D<M, N>, then both the type (which becomes Tensor2D<N, M>) and the underlying strides need to change.

  2. Transpose operation - this seems way more feasible in the short term, though it would incur some extra memory copies (I'm not sure how much they would actually impact performance). I'd recommend not moving Linear to this yet, because it would have to put the weight matrix on the tape, then do the transpose, then put the tape back onto the input so all the operations get recorded. That shouldn't cost anything at runtime, but it would make the implementation less clear.

So tl;dr: go with the less efficient transpose operation for now, IMO.

coreylowman avatar Jul 25 '22 13:07 coreylowman

Ok, I ended up not needing transpose for MultiHeadAttention, but I implemented a forward pass of it anyway for Tensor3D:

impl<const A: usize, const B: usize, const C: usize> Transpose<0, 1> for Tensor3D<A, B, C> {
    type Output = Tensor3D<B, A, C>;
    fn transpose(self) -> Self::Output {
        let mut new = Tensor3D::zeros();
        let data = self.data();
        let new_data = new.mut_data();

        // Swap the first two axes; each assignment copies a whole [f32; C] row.
        #[allow(clippy::needless_range_loop)]
        for i in 0..B {
            for j in 0..A {
                new_data[i][j] = data[j][i];
            }
        }
        new
    }
}

I agree this is atrocious, and it doesn't even properly pass the grads backward. Currently it's not used for anything, so we can scrap it and try a better way.

As far as the slices change, I think it would be fine to have a separate data storage struct not tied to const generics, since only we would be using it, and once it's linked correctly to the tensors and their operations we won't need to worry about it anymore. I also think it will be more or less necessary for when we work on GPU kernels. It would also let us cut down on matmul functions: no separate matmul_transpose function is needed when transpose is effectively zero cost.

I really think this should be prioritized so work going forward doesn't need to be redone when the switch eventually happens.

jafioti avatar Jul 25 '22 15:07 jafioti

@jafioti @coreylowman +1 on prioritizing the slices change. Now is the time to make massive changes like this; once we have more users (which we will once we add GPU support), it will be much harder to do, since many more downstream functions will depend on the internals.

Things like reducing across multiple dimensions will also be much easier with a slices/strides based approach.

vikigenius avatar Jul 25 '22 15:07 vikigenius

It's going to require some thinking & design work to implement things like reduction across multiple dimensions & permutations in a way that doesn't need a method per possible permutation/axis, and is also user friendly. Permutation, for example, has to update both the const generic parameters (at compile time) and the strides (at run time).

I'm going to focus on conv nets (& hopefully help with transformers given enough time) to get more usable features out first & help illuminate the path for the strides rewrite. Those are two things that will require a lot of additions to the internals, which will help focus the rewrite. For example, transformers already required a ton of functionality that wasn't obviously needed up front.

coreylowman avatar Jul 25 '22 20:07 coreylowman
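The run-time half of that split can be sketched as a plain dim/stride reorder; the compile-time half (producing the right output tensor type) is the part the const-generic machinery has to solve separately. permute3 here is a hypothetical helper, not dfdx API.

```rust
// Hypothetical sketch of the run-time side of a permute: reorder dims and
// strides by the permutation. Output axis n takes input axis perm[n].
// The compile-time side (the new tensor type) is handled elsewhere.
fn permute3(
    dims: [usize; 3],
    strides: [usize; 3],
    perm: [usize; 3],
) -> ([usize; 3], [usize; 3]) {
    let new_dims = [dims[perm[0]], dims[perm[1]], dims[perm[2]]];
    let new_strides = [strides[perm[0]], strides[perm[1]], strides[perm[2]]];
    (new_dims, new_strides)
}
```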

I don't know whether I should be proud that I figured out how to implement this via macros or horrified...

impl_permute!(0, 1, 3, 2);
impl_permute!(0, 2, 1, 3);
impl_permute!(0, 2, 3, 1);
impl_permute!(0, 3, 1, 2);
impl_permute!(0, 3, 2, 1);
impl_permute!(1, 0, 2, 3);
impl_permute!(1, 0, 3, 2);
impl_permute!(1, 2, 0, 3);
impl_permute!(1, 2, 3, 0);
impl_permute!(1, 3, 0, 2);
impl_permute!(1, 3, 2, 0);
impl_permute!(2, 0, 1, 3);
impl_permute!(2, 0, 3, 1);
impl_permute!(2, 1, 0, 3);
impl_permute!(2, 1, 3, 0);
impl_permute!(2, 3, 0, 1);
impl_permute!(2, 3, 1, 0);
impl_permute!(3, 0, 1, 2);
impl_permute!(3, 0, 2, 1);
impl_permute!(3, 1, 0, 2);
impl_permute!(3, 1, 2, 0);
impl_permute!(3, 2, 0, 1);
impl_permute!(3, 2, 1, 0);

coreylowman avatar Aug 22 '22 01:08 coreylowman
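For readers unfamiliar with the pattern: each impl_permute! invocation stamps out one permutation's implementation, so the 23 invocations above cover every non-identity permutation of 4 axes. As a hypothetical, much-simplified illustration (the real macro generates tensor trait impls and takes no name argument), a macro like this shows the shape of the trick:

```rust
// Hypothetical minimal demonstration of the macro pattern: one invocation
// per permutation, each generating a function that reorders a [usize; 3].
// Illustrative only; the real impl_permute! generates trait impls.
macro_rules! impl_permute {
    ($name:ident, $i:expr, $j:expr, $k:expr) => {
        fn $name(d: [usize; 3]) -> [usize; 3] {
            [d[$i], d[$j], d[$k]]
        }
    };
}

impl_permute!(permute_120, 1, 2, 0);
impl_permute!(permute_201, 2, 0, 1);
```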

@coreylowman Oh ... wow. Yeah, an optimal setup would be a single generic permute function that takes four const arguments and returns different shapes based on them, but I have no idea if that's even possible today (probably not). Till then it seems like we're stuck with these macros.

A slightly better way to do this would be a proc macro that generates every possible permutation, so you can't miss any, but that would require a separate crate for the proc macros and be a big hassle, so this seems good for now.

jafioti avatar Aug 22 '22 03:08 jafioti

Yeah, as of now you'd need to programmatically specify the output type, which isn't possible on stable at least (maybe on nightly?)

which I think may look something like this?

trait Permute3<I, J, K>: HasAxis<I> + HasAxis<J> + HasAxis<K> {
    fn permute(self) -> Tensor3D<<Self as HasAxis<I>>::SIZE, <Self as HasAxis<J>>::SIZE, <Self as HasAxis<K>>::SIZE>;
}

not sure what feature on nightly allows you to use associated consts as generic arguments.

I think the other problem (the permuted index) is easier to solve:

for m in 0..M {
    for n in 0..N {
        for o in 0..O {
             *permuted3_idx::<I, J, K>(output.mut_data(), [m, n, o]) = input.data()[m][n][o];
        }
    }
}

coreylowman avatar Aug 22 '22 11:08 coreylowman
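In the spirit of the permuted3_idx snippet above, a copying permute over flat row-major buffers can be sketched on stable Rust. permute3_copy and its perm convention (output axis n takes its coordinate from input axis perm[n]) are illustrative, not dfdx API.

```rust
// Hypothetical sketch of a copying 3d permute over flat row-major buffers.
// `perm[n]` names the input axis that output axis n comes from. Not dfdx API.
fn permute3_copy(input: &[f32], dims: [usize; 3], perm: [usize; 3]) -> Vec<f32> {
    let out_dims = [dims[perm[0]], dims[perm[1]], dims[perm[2]]];
    let in_strides = [dims[1] * dims[2], dims[2], 1];
    let mut out = vec![0.0; input.len()];
    for a in 0..out_dims[0] {
        for b in 0..out_dims[1] {
            for c in 0..out_dims[2] {
                // Map the output coordinate back to an input coordinate.
                let mut idx = [0usize; 3];
                idx[perm[0]] = a;
                idx[perm[1]] = b;
                idx[perm[2]] = c;
                let src = idx[0] * in_strides[0]
                    + idx[1] * in_strides[1]
                    + idx[2] * in_strides[2];
                out[(a * out_dims[1] + b) * out_dims[2] + c] = input[src];
            }
        }
    }
    out
}
```

Swapping axes 0 and 1 of a 2x2x2 tensor with perm = [1, 0, 2] reorders the flat buffer without changing any element values.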

Yep, confirmed that with generic_const_exprs this is possible with just traits. Here's a 2d version: https://play.rust-lang.org/?version=nightly&mode=debug&edition=2021&gist=114492951baef2c0c68deb51a59ab401.

I'd like permute to be available on stable though. So for now I'm going to do macros, and then I'll file a follow-up issue to refactor to this method once generic_const_exprs is stable.

coreylowman avatar Aug 22 '22 23:08 coreylowman