ChainPlots.jl
Can this work with custom layers?
Hi,
If I have custom layers that are based on Flux layers, can ChainPlots be used with them directly, or do I need to wrap everything in a Chain?
If your custom layer is a subtype of a Flux layer, then it might work directly, depending on how it is implemented. If it is not a subtype, then you will need to Chain the layer, even if it is just a single layer, or extend some methods to your type.
Let me know if you have any trouble. In that case, posting an MWE (minimal working example) here should help me sort it out.
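A minimal sketch of the "wrap it in a Chain" route, assuming Flux 0.13 or later (for the `Dense(in => out)` syntax). `ScaledDense` is a hypothetical layer invented for this illustration and is not a subtype of any Flux type. The `plot` call is left commented out because, as noted above, whether ChainPlots can draw such a layer depends on how the layer is implemented.

```julia
using Flux

# Hypothetical custom layer (illustration only): it is NOT a subtype of a Flux layer.
# It applies a Dense layer and then scales the output by a constant.
struct ScaledDense{D}
    dense::D
    scale::Float32
end

Flux.@functor ScaledDense

(m::ScaledDense)(x) = m.scale .* m.dense(x)

layer = ScaledDense(Dense(2 => 3, relu), 0.5f0)

# Wrap the single custom layer in a Chain so the ChainPlots recipe,
# which dispatches on Chain, has a Chain to work with.
model = Chain(layer)

x = rand(Float32, 2)
out = model(x)        # the wrapped layer evaluates exactly as before

# using ChainPlots, Plots
# plot(model, x)      # may or may not work, depending on the layer's implementation
```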
Since this is open, I guess the problem I am facing falls into this category. I have created a Concat struct that stacks Dense layers "on top" of one another, each branch receiving its own input, and then merges their outputs with a Dense layer. This is inspired by this answer. The model can be applied to a correctly shaped input, but it cannot be visualized when the same input is passed as the second argument of the plot function.
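The stack-then-merge idea can be illustrated in plain Julia without Flux: `mapreduce` with two collections calls the function on each (branch, input) pair and joins the results with `vcat`, which is the pattern the Concat call method below relies on. In this sketch, the closures `branch1`, `branch2` and `merge_layer` are hypothetical stand-ins for the Dense layers, with made-up arithmetic chosen only to make the result easy to check.

```julia
# Two "branches" standing in for the Dense layers (illustrative closures, not Flux layers)
branch1 = x -> 2 .* x
branch2 = x -> x .+ 1

# Pair the i-th branch with the i-th input and concatenate the outputs with vcat
stacked(xs) = mapreduce((f, x) -> f(x), vcat, (branch1, branch2), xs)

# Stand-in for the merging layer: reduce the stacked features to one output
merge_layer = v -> [sum(v)]

inputs = ([1.0], [2.0])
h = stacked(inputs)    # [2.0, 3.0]
y = merge_layer(h)     # [5.0]
```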
MWE:
using Flux, ChainPlots, NNlib, Plots
# To concatenate layers
struct Concat{T}
    catted::T
end
Concat(xs...) = Concat(xs)
Flux.@functor Concat
function (C::Concat)(x)
    # apply the i-th layer to the i-th input and vcat the results
    mapreduce((f, x) -> f(x), vcat, C.catted, x)
end
wdt = 16 # width of hidden layers
ϕ = Chain(
    Concat(
        Dense(1 => wdt, swish),
        Dense(1 => wdt, swish)
    ),
    Dense(2 * wdt => wdt, swish),
    Dense(wdt => 1)
)
input = [ones(1),ones(1)] .|> Vector{Float32}
ϕ(input)
# 1-element Vector{Float32}:
# 0.05928603
plot(ϕ, input)
ERROR: MethodError: Cannot `convert` an object of type Vector{Float32} to an object of type Float32
Closest candidates are:
convert(::Type{T}, ::Base.TwicePrecision) where T<:Number
@ Base twiceprecision.jl:273
convert(::Type{T}, ::AbstractChar) where T<:Number
@ Base char.jl:185
convert(::Type{T}, ::CartesianIndex{1}) where T<:Number
@ Base multidimensional.jl:127
...
Stacktrace:
[1] _broadcast_getindex_evalf
@ ./broadcast.jl:683 [inlined]
[2] _broadcast_getindex
@ ./broadcast.jl:666 [inlined]
[3] getindex
@ ./broadcast.jl:610 [inlined]
[4] copy
@ ./broadcast.jl:912 [inlined]
[5] materialize
@ ./broadcast.jl:873 [inlined]
[6] get_dimensions(m::Chain{Tuple{Concat{Tuple{Dense{typeof(swish), Matrix{Float32}, Vector{Float32}}, Dense{typeof(swish), Matrix{Float32}, Vector{Float32}}}}, Dense{typeof(swish), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, input_data::Vector{Vector{Float32}})
@ ChainPlots ~/.julia/packages/ChainPlots/pfCTF/src/utils.jl:30
[7] macro expansion
@ ~/.julia/packages/ChainPlots/pfCTF/src/plotrecipe.jl:49 [inlined]
[8] apply_recipe(plotattributes::AbstractDict{Symbol, Any}, m::Chain, input_data::Union{Nothing, Array})
@ ChainPlots ~/.julia/packages/RecipesBase/BRe07/src/RecipesBase.jl:300
[9] _process_userrecipes!(plt::Any, plotattributes::Any, args::Any)
@ RecipesPipeline ~/.julia/packages/RecipesPipeline/BGM3l/src/user_recipe.jl:38
[10] recipe_pipeline!(plt::Any, plotattributes::Any, args::Any)
@ RecipesPipeline ~/.julia/packages/RecipesPipeline/BGM3l/src/RecipesPipeline.jl:72
[11] _plot!(plt::Plots.Plot, plotattributes::Any, args::Any)
@ Plots ~/.julia/packages/Plots/sxUvK/src/plot.jl:223
[12] #plot#188
@ ~/.julia/packages/Plots/sxUvK/src/plot.jl:102 [inlined]
[13] plot(::Any, ::Any)
@ Plots ~/.julia/packages/Plots/sxUvK/src/plot.jl:93
[14] top-level scope
@ REPL[15]:1
Thanks for reporting. I managed to fix the conversion, but there is another error after that, which I still need to figure out.
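The traceback shows `get_dimensions` failing while converting the nested `Vector{Vector{Float32}}` input, which suggests ChainPlots expects a single numeric array. Under that assumption, one possible workaround (untested with ChainPlots) is to pass one `Float32` vector and split it inside the layer. `SplitConcat` is a hypothetical variant of `Concat` written for this sketch, not part of Flux or ChainPlots, and it assumes Flux re-exports `swish` from NNlib.

```julia
using Flux

# Hypothetical variant of Concat: takes ONE input vector and feeds
# element i (as a 1-element slice) to branch i.
struct SplitConcat{T}
    branches::T
end
SplitConcat(xs...) = SplitConcat(xs)
Flux.@functor SplitConcat

function (C::SplitConcat)(x::AbstractVector)
    # x[i:i] keeps a 1-element vector, matching each Dense(1 => wdt) branch
    mapreduce(i -> C.branches[i](x[i:i]), vcat, eachindex(C.branches))
end

wdt = 16 # width of hidden layers
ϕ = Chain(
    SplitConcat(Dense(1 => wdt, swish), Dense(1 => wdt, swish)),
    Dense(2 * wdt => wdt, swish),
    Dense(wdt => 1)
)

input = ones(Float32, 2)  # one vector instead of [ones(1), ones(1)]
out = ϕ(input)            # 1-element Vector{Float32}

# using ChainPlots, Plots
# plot(ϕ, input)          # may still hit the second error mentioned above
```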