Enzyme-JAX
Enzyme-JAX copied to clipboard
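[All2All] Generalize triple concat

Example: the IR below concatenates three operands along dim 0. The operands are an 8-row slice from the start of `%arg31` (rows 0:8), a 512-row computed tensor (`%16118`), and an 8-row slice from the end of `%arg31` (rows 520:528). The result is a 528x496x240 tensor (8 + 512 + 8 = 528).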
%16123 = stablehlo.slice %arg31 [0:8, 8:504, 8:248] : (tensor<528x512x256xf64>) -> tensor<8x496x240xf64> loc(#loc2609)
%16118 = stablehlo.negate %16117 : tensor<512x496x240xf64> loc(#loc1468)
%16124 = stablehlo.slice %arg31 [520:528, 8:504, 8:248] : (tensor<528x512x256xf64>) -> tensor<8x496x240xf64> loc(#loc2609)
%16125 = stablehlo.concatenate %16123, %16118, %16124, dim = 0 : (tensor<8x496x240xf64>, tensor<512x496x240xf64>, tensor<8x496x240xf64>) -> tensor<528x496x240xf64> loc(#loc2609)