Enzyme-JAX
```
%10899 = "enzymexla.comm_region"() ({
  %12500 = stablehlo.slice %5743 [0:1, 0:8, 0:80] {sdy.sharding = #sdy.sharding_per_value} : (tensor) -> tensor
  %12501 = stablehlo.slice %5743 [0:2, 0:8, 0:80] {sdy.sharding = #sdy.sharding_per_value} :...
```
On commit 633c33a97a3a625804876f1dc40a252e5b88ba4a, the module https://gist.github.com/giordano/debe26b94593d8f228135aab39d24b91 results in:
```
loc(callsite(fused["/home/giordano/.julia/packages/LLVM/b3kFs/src/interop/base.jl":39:0] at callsite(fused["none":0:0] at callsite(fused["none":0:0] at callsite(fused["/home/giordano/.julia/packages/LLVM/b3kFs/src/interop/pointer.jl":88:0] at callsite(fused["/home/giordano/.julia/dev/Reactant/ext/ReactantCUDAExt.jl":271:0] at callsite(fused["/home/giordano/.julia/dev/Reactant/ext/ReactantCUDAExt.jl":264:0] at callsite(fused["/home/giordano/.julia/dev/Reactant/ext/ReactantCUDAExt.jl":306:0] at callsite(fused["/home/giordano/.julia/dev/Reactant/ext/ReactantCUDAExt.jl":318:0] at callsite(fused["/home/giordano/.julia/packages/OffsetArrays/yHW0g/src/OffsetArrays.jl":442:0] at callsite(fused["/home/giordano/.julia/packages/KernelAbstractions/sWSE0/src/macros.jl":322:0] at fused["none":0:0]))))))))))): error:...
```
WIP, looking at whether any changes are needed to the padding part of the code
See the while DUS (dynamic_update_slice) pass.
If an op has multiple slices with rotate users, we should instead emit one rotate on a larger slice covering them and slice into that rotate (essentially we want...
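The proposed rewrite can be sketched in NumPy, two-dimensional for brevity. This is a hedged illustration, not the Enzyme-JAX pass: `np.roll` stands in for a rotate op (the shift direction of `enzymexla.rotate` is an assumption), and the helper names `rotate`, `rotate_each_slice`, and `rotate_via_merged_slice` are hypothetical. The sketch only covers slices that share the same bounds on the rotated axis and differ along another axis. In that case rotating the bounding slice once and slicing into the result yields the same values as rotating each slice separately, because the rotation acts on each row independently.

```python
import numpy as np


def rotate(x, amount, axis):
    # Cyclic left shift by `amount` along `axis`, modelled with np.roll.
    # Assumption: enzymexla.rotate may use the opposite direction convention.
    return np.roll(x, -amount, axis=axis)


def rotate_each_slice(x, row_ranges, cols, amount):
    # Original pattern: one slice per row range, each with its own rotate on axis 1.
    c0, c1 = cols
    return [rotate(x[r0:r1, c0:c1], amount, axis=1) for r0, r1 in row_ranges]


def rotate_via_merged_slice(x, row_ranges, cols, amount):
    # Proposed pattern: slice the bounding row range once, rotate once,
    # then slice each original range out of the rotated result.
    # Valid only because every slice uses the same `cols` bounds on the rotated axis.
    c0, c1 = cols
    lo = min(r0 for r0, _ in row_ranges)
    hi = max(r1 for _, r1 in row_ranges)
    merged = rotate(x[lo:hi, c0:c1], amount, axis=1)
    return [merged[r0 - lo:r1 - lo] for r0, r1 in row_ranges]


# Usage: overlapping row ranges, similar in spirit to [0:1, ...] and [0:2, ...] above.
x = np.arange(4 * 96).reshape(4, 96)
ranges = [(0, 1), (0, 2), (2, 4)]
for a, b in zip(rotate_each_slice(x, ranges, (8, 88), 3),
                rotate_via_merged_slice(x, ranges, (8, 88), 3)):
    assert np.array_equal(a, b)
```

The merged form trades several rotates for one rotate plus cheap slices; if the slices disagree on the rotated axis, the bounding slice would wrap different elements and the equivalence no longer holds.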
```
%c_280 = stablehlo.constant dense : tensor
%c_277 = stablehlo.constant dense : tensor
%c_281 = stablehlo.constant dense : tensor
%35 = "enzymexla.comm_region"() ({
  %11832 = stablehlo.slice %arg23 [8:9, 8:16, 8:88]...
```
```
%c_118 = stablehlo.constant dense : tensor
%5329 = stablehlo.select %c_118, %5324, %5328 : tensor, tensor
```
```
%c_121 = stablehlo.constant dense : tensor
%5278 = stablehlo.select %c_121, %5277, %cst_122...
```
```
%16123 = stablehlo.slice %arg31 [0:8, 8:504, 8:248] : (tensor) -> tensor loc(#loc2609)
%16118 = stablehlo.negate %16117 : tensor loc(#loc1468)
%16124 = stablehlo.slice %arg31 [520:528, 8:504, 8:248] : (tensor) ->...
```