Enzyme-JAX
Enzyme-JAX copied to clipboard
Push reshapes up
- [ ] reshape(slice) -> slice(reshape) ~(@jumerckx, https://github.com/EnzymeAD/Enzyme-JAX/pull/581)~
```mlir
%3250 = stablehlo.reshape %125 {mhlo.sharding = "{devices=[2,1,2,2]<=[2,2,2]T(1,0,2) last_tile_dim_replicate}"} : (tensor<268x2060xf64>) -> tensor<268x1x2060xf64>
%3377 = stablehlo.slice %3250 [6:262, 0:1, 5:2053] {mhlo.sharding = "{devices=[2,1,2,2]<=[2,2,2]T(1,0,2) last_tile_dim_replicate}"} : (tensor<268x1x2060xf64>) -> tensor<256x1x2048xf64>
%3794 = stablehlo.reshape %3377 {mhlo.sharding = "{devices=[1,2,2,2]<=[2,2,2]T(1,0,2) last_tile_dim_replicate}"} : (tensor<256x1x2048xf64>) -> tensor<1x256x2048xf64>
```
- [x] reshape(elementwise) -> elementwise(reshape)
- [x] reshape(concat) -> concat(reshape) @chelini (https://github.com/EnzymeAD/Enzyme-JAX/pull/577/)
We only need to handle the case where the reshape adds a singleton (size-1) dimension.