Zygote.jl

Fails to differentiate integer powers of Metal arrays

Open • CarloLucibello opened this issue 1 year ago • 1 comment

I see the following error with this simple example involving Metal arrays:

julia> using Metal, Zygote

julia> x = Metal.ones(2)

julia> gradient(x -> sum(x.^2), x)
ERROR: InvalidIRError: compiling MethodInstance for (::Metal.var"#broadcast_cartesian_static#213")(::MtlDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Metal.StaticCartesianIndices{…}) resulted in invalid LLVM IR
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] #power_by_squaring#526
   @ ./intfuncs.jl:0
 [2] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] #power_by_squaring#526
   @ ./intfuncs.jl:0
 [2] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] *
   @ ./float.jl:493
 [2] #power_by_squaring#526
   @ ./intfuncs.jl:320
 [3] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] ==
   @ ./float.jl:616
 [2] isone
   @ ./number.jl:62
 [3] #power_by_squaring#526
   @ ./intfuncs.jl:322
 [4] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
.... repeated multiple times .......
Stacktrace:
  [1] Float32
    @ ./float.jl:338
  [2] ^
    @ ./math.jl:1231
  [3] _broadcast_getindex_evalf
    @ ./broadcast.jl:673
  [4] _broadcast_getindex
    @ ./broadcast.jl:646
  [5] _getindex
    @ ./broadcast.jl:670
  [6] _broadcast_getindex
    @ ./broadcast.jl:645
  [7] _getindex
    @ ./broadcast.jl:670
  [8] _getindex
    @ ./broadcast.jl:669
  [9] _broadcast_getindex
    @ ./broadcast.jl:645
 [10] getindex
    @ ./broadcast.jl:605
 [11] broadcast_cartesian_static
    @ ~/.julia/packages/Metal/rBb1i/src/broadcast.jl:67
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, args::LLVM.Module)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/validation.jl:147
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:382 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/NRdsv/src/TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:381 [inlined]
  [5] 
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:108
  [6] 
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:100
  [7] codegen
    @ ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:82 [inlined]
  [8] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:79
  [9] compile
    @ ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:74 [inlined]
 [10] (::Metal.var"#154#162"{GPUCompiler.CompilerJob{…}})(ctx::LLVM.Context)
    @ Metal ~/.julia/packages/Metal/rBb1i/src/compiler/compilation.jl:108
 [11] JuliaContext(f::Metal.var"#154#162"{GPUCompiler.CompilerJob{…}}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:34
 [12] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:25
 [13] macro expansion
    @ ~/.julia/packages/Metal/rBb1i/src/compiler/compilation.jl:107 [inlined]
 [14] macro expansion
    @ ~/.julia/packages/ObjectiveC/C7BVt/src/os.jl:264 [inlined]
 [15] compile(job::GPUCompiler.CompilerJob)
    @ Metal ~/.julia/packages/Metal/rBb1i/src/compiler/compilation.jl:105
 [16] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(Metal.compile), linker::typeof(Metal.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/execution.jl:237
 [17] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/execution.jl:151
 [18] macro expansion
    @ ~/.julia/packages/Metal/rBb1i/src/compiler/execution.jl:189 [inlined]
 [19] macro expansion
    @ ./lock.jl:273 [inlined]
 [20] mtlfunction(f::Metal.var"#broadcast_cartesian_static#213", tt::Type{Tuple{…}}; name::Nothing, kwargs::@Kwargs{})
    @ Metal ~/.julia/packages/Metal/rBb1i/src/compiler/execution.jl:184
 [21] mtlfunction(f::Metal.var"#broadcast_cartesian_static#213", tt::Type{Tuple{…}})
    @ Metal ~/.julia/packages/Metal/rBb1i/src/compiler/execution.jl:182
 [22] macro expansion
    @ ~/.julia/packages/Metal/rBb1i/src/compiler/execution.jl:85 [inlined]
 [23] _copyto!
    @ ~/.julia/packages/Metal/rBb1i/src/broadcast.jl:74 [inlined]
 [24] copyto!
    @ ~/.julia/packages/Metal/rBb1i/src/broadcast.jl:47 [inlined]
 [25] copy
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:29 [inlined]
 [26] materialize
    @ ./broadcast.jl:867 [inlined]
 [27] (::Zygote.var"#1257#1260"{2, MtlVector{Float32, Metal.PrivateStorage}})(ȳ::MtlVector{Float32, Metal.PrivateStorage})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/lib/broadcast.jl:108
 [28] #3916#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [29] #631
    @ ./REPL[9]:1 [inlined]
 [30] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface.jl:91
 [31] gradient(f::Function, args::MtlVector{Float32, Metal.PrivateStorage})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface.jl:148
 [32] top-level scope
    @ REPL[9]:1
Some type information was truncated. Use `show(err)` to see complete types.

I get a similar error for x.^3 but not for x.^2f0. cc @maleadt

CarloLucibello, Oct 17 '24 05:10
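The report says `x.^3` fails while `x.^2f0` works. That pattern fits the exponent type deciding which code path runs. Julia lowers an integer-literal exponent such as `x^2` to `Base.literal_pow(^, x, Val(2))`. The `{2, MtlVector{…}}` closure in the trace suggests that Zygote's rule then receives `p = 2` and computes `x .^ (p - 1)`, where the exponent is an ordinary `Int` handled by the generic `^`. The snippet below is plain Julia and needs neither Metal nor Zygote. It checks only the lowering half of that explanation; the rest is an inference from the stack trace.

```julia
# Lower two expressions without evaluating them (x and n need not exist):
# one with a literal integer exponent, one with the exponent in a variable.
literal = string(Meta.lower(Main, :(x^2)))
runtime = string(Meta.lower(Main, :(x^n)))

# The literal form is rewritten to Base.literal_pow(^, x, Val(2));
# the variable form stays a plain call to ^.
println(occursin("literal_pow", literal))  # true
println(occursin("literal_pow", runtime))  # false
```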

I think x.^2 and x.^3 should go through the rule linked below. Is x .^ (p - 1) producing Float64 somehow?

https://github.com/FluxML/Zygote.jl/blob/59e7ec1b32c4d4571120e143d9c40e48b370f22b/src/lib/broadcast.jl#L106-L109

mcabbott, Oct 18 '24 01:10
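Below is a CPU-only sketch of what the linked rule appears to compute for `sum(x .^ p)`. It assumes, following the question above, that the pullback is `ȳ .* p .* x .^ (p - 1)`. `pow_same_type` and `pullback_pow` are hypothetical names, not Zygote API. The helper performs exponentiation by squaring and keeps every intermediate in the input's float type. A Metal kernel would need that property if, as the `power_by_squaring` and `Float32` frames in the trace suggest, the built-in `^(::Float32, ::Integer)` goes through a Float64 intermediate.

```julia
# Hypothetical helper: exponentiation by squaring that keeps every
# intermediate in the input's own type (e.g. Float32), so no Float64
# values are created. This sketch assumes n >= 0.
function pow_same_type(x::T, n::Integer) where {T<:AbstractFloat}
    n >= 0 || throw(DomainError(n, "this sketch only handles n >= 0"))
    result = one(T)
    base = x
    while n > 0
        if isodd(n)
            result *= base   # fold in the current power of x
        end
        base *= base         # x, x^2, x^4, ...
        n >>= 1
    end
    return result
end

# d/dx x^p = p * x^(p - 1), applied elementwise and scaled by the
# incoming cotangent ȳ, as a reverse-mode pullback does.
pullback_pow(ȳ, x, p::Integer) = ȳ .* p .* pow_same_type.(x, p - 1)

x = ones(Float32, 2)
ȳ = ones(Float32, 2)   # cotangent of sum(...) w.r.t. each element of x.^p
g = pullback_pow(ȳ, x, 2)
println(g)             # Float32[2.0, 2.0]
println(eltype(g))     # Float32
```

Whether this avoids the InvalidIRError on an actual MtlArray is untested. The workaround already observed in the report is to use a Float32 exponent such as x.^2f0.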