Yota.jl

Support array comprehension

Open · dfdx opened this issue 3 years ago · 1 comment

Support simple forms of array comprehension, throw a meaningful error for more complex forms.

dfdx · Feb 21 '21 16:02

Array comprehension can now be traced, but differentiation fails due to missing derivatives:

julia> trace(xs -> [x + 1 for x in xs], rand(3))
([1.7233140265316134, 1.1141674288968166, 1.592317598761971], Tape{Dict{Any, Any}}
  inp %1::var"#59#61"
  inp %2::Vector{Float64}
  %3 = __new__(var"#60#62")::var"#60#62"
  %4 = apply_type(Base.Generator, Vector{Float64}, var"#60#62")::DataType
  %5 = apply_type(Base.Generator, Vector{Float64}, var"#60#62")::DataType
  %6 = convert(var"#60#62", %3)::var"#60#62"
  %7 = convert(Vector{Float64}, %2)::Vector{Float64}
  %8 = __new__(%5, %6, %7)::Base.Generator{Vector{Float64}, var"#60#62"}
  %9 = collect(%8)::Vector{Float64}
)
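
The tape above follows from how Julia lowers a comprehension: it wraps the element function and the iterable in a `Base.Generator` and then calls `collect` on it. A minimal sketch in plain Julia, independent of Yota, showing the equivalence (the names `f`, `gen`, `manual` and `comprehension` are illustrative only):

```julia
xs = [1.0, 2.0, 3.0]

# The element function of the comprehension, written as an explicit closure.
f = x -> x + 1

# Base.Generator stores the function in `gen.f` and the iterable in `gen.iter`.
gen = Base.Generator(f, xs)

# Collecting the generator materializes the array, like op %9 on the tape.
manual = collect(gen)

# The comprehension form produces the same values.
comprehension = [x + 1 for x in xs]
```

This is why the only op that needs a derivative rule is the final `collect` call on the generator.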

julia> grad(xs -> sum([x + 1 for x in xs]), rand(3))
ERROR: No deriative rule found for op %9 = collect(%8)::Vector{Float64}, try defining it using ChainRules.rrule(::typeof(collect), ::Base.Generator{Vector{Float64}, var"#64#66"}) = ...
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] step_back!(tape::Tape{GradCtx}, y::Variable, deriv_todo::Vector{Variable})
   @ Main ~/work/Yota/src/grad.jl:128
 [3] back!(tape::Tape{GradCtx}; seed::Int64)
   @ Main ~/work/Yota/src/grad.jl:176
 [4] gradtape!(tape::Tape{GradCtx}; seed::Int64)
   @ Main ~/work/Yota/src/grad.jl:197
 [5] gradtape(f::var"#63#65", args::Vector{Float64}; seed::Int64)
   @ Main ~/work/Yota/src/grad.jl:210
 [6] grad(f::var"#63#65", args::Vector{Float64}; seed::Int64)
   @ Main ~/work/Yota/src/grad.jl:275
 [7] grad(f::var"#63#65", args::Vector{Float64})
   @ Main ~/work/Yota/src/grad.jl:270
 [8] top-level scope
   @ REPL[5]:1
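
The error message asks for a reverse rule for `collect` applied to a `Base.Generator`. The sketch below shows only the shape such a rule might take. It is not Yota's implementation and it is not registered with ChainRules. `collect_with_pullback` and `fd_derivative` are hypothetical names, the element function is assumed to map a real scalar to a real scalar, and the per-element derivative comes from a finite difference as a stand-in for whatever the AD system would supply.

```julia
# Hypothetical helper: central finite-difference derivative of a scalar function.
# A real rule would get this derivative from the AD system instead.
fd_derivative(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# Sketch of a reverse rule for collect(::Base.Generator) over a vector of reals.
# Returns the collected array and a pullback that maps the output cotangent `dy`
# to the cotangent of the input vector `gen.iter`.
# Assumption: `gen.f` is scalar -> scalar.
function collect_with_pullback(gen::Base.Generator{<:AbstractVector{<:Real}})
    xs = gen.iter
    y = collect(gen)
    # Each output element depends on exactly one input element,
    # so the Jacobian is diagonal and one derivative per element suffices.
    derivs = [fd_derivative(gen.f, x) for x in xs]
    pullback(dy) = dy .* derivs
    return y, pullback
end

gen = (x + 1 for x in [1.0, 2.0, 3.0])
y, pb = collect_with_pullback(gen)

# Seeding with ones corresponds to the cotangent that `sum` passes back.
dxs = pb(ones(3))
```

A complete rule would also have to handle closures that capture differentiated variables, which this sketch ignores.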

dfdx · Jul 03 '21 20:07