ChainPlots.jl

Can this work with custom layers?

Open · mashu opened this issue on Aug 22 '23 · 3 comments

Hi,

If I have custom layers that are based on Flux layers, can this package be used with them, or do I need to wrap everything in a Chain?

mashu · Aug 22 '23 11:08

If your custom layer is a subtype of a Flux layer, it might work directly, depending on how it is implemented. If it is not a subtype, you will need to wrap the layer in a Chain, even if it is just a single layer, or extend some methods for your type.

Let me know if you have any trouble. In that case, posting an MWE (minimal working example) here should help me sort it out.

rmsrosa · Aug 22 '23 12:08

Since this issue is still open, I think my problem falls into this category. I created a Concat struct that stacks Dense layers "on top of" one another, each branch receiving its own input, and then merges their outputs with another Dense layer. This is inspired by this answer. The model works when applied to a correctly shaped input, but it cannot be visualized when the same input is passed as the second argument of the plot function.

MWE:

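```julia
using Flux, ChainPlots, NNlib, Plots

# Custom layer that applies the i-th wrapped layer to the i-th input
# and concatenates the outputs vertically
struct Concat{T}
    catted::T
end
Concat(xs...) = Concat(xs)

Flux.@functor Concat  # expose the parameters of the wrapped layers to Flux

function (C::Concat)(x)
    # pair each layer f with its input xi, then vcat the results
    mapreduce((f, xi) -> f(xi), vcat, C.catted, x)
end

wdt = 16  # width of the hidden layers
ϕ = Chain(
    Concat(
        Dense(1 => wdt, swish),
        Dense(1 => wdt, swish),
    ),
    Dense(2 * wdt => wdt, swish),
    Dense(wdt => 1),
)

# two 1-element Float32 vectors, one per branch of Concat
input = [ones(1), ones(1)] .|> Vector{Float32}
ϕ(input)
# 1-element Vector{Float32}:
#  0.05928603   (value depends on the random initialization)

plot(ϕ, input)
```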
```
ERROR: MethodError: Cannot `convert` an object of type Vector{Float32} to an object of type Float32

Closest candidates are:
  convert(::Type{T}, ::Base.TwicePrecision) where T<:Number
   @ Base twiceprecision.jl:273
  convert(::Type{T}, ::AbstractChar) where T<:Number
   @ Base char.jl:185
  convert(::Type{T}, ::CartesianIndex{1}) where T<:Number
   @ Base multidimensional.jl:127
  ...

Stacktrace:
  [1] _broadcast_getindex_evalf
    @ ./broadcast.jl:683 [inlined]
  [2] _broadcast_getindex
    @ ./broadcast.jl:666 [inlined]
  [3] getindex
    @ ./broadcast.jl:610 [inlined]
  [4] copy
    @ ./broadcast.jl:912 [inlined]
  [5] materialize
    @ ./broadcast.jl:873 [inlined]
  [6] get_dimensions(m::Chain{Tuple{Concat{Tuple{Dense{typeof(swish), Matrix{Float32}, Vector{Float32}}, Dense{typeof(swish), Matrix{Float32}, Vector{Float32}}}}, Dense{typeof(swish), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, input_data::Vector{Vector{Float32}})
    @ ChainPlots ~/.julia/packages/ChainPlots/pfCTF/src/utils.jl:30
  [7] macro expansion
    @ ~/.julia/packages/ChainPlots/pfCTF/src/plotrecipe.jl:49 [inlined]
  [8] apply_recipe(plotattributes::AbstractDict{Symbol, Any}, m::Chain, input_data::Union{Nothing, Array})
    @ ChainPlots ~/.julia/packages/RecipesBase/BRe07/src/RecipesBase.jl:300
  [9] _process_userrecipes!(plt::Any, plotattributes::Any, args::Any)
    @ RecipesPipeline ~/.julia/packages/RecipesPipeline/BGM3l/src/user_recipe.jl:38
 [10] recipe_pipeline!(plt::Any, plotattributes::Any, args::Any)
    @ RecipesPipeline ~/.julia/packages/RecipesPipeline/BGM3l/src/RecipesPipeline.jl:72
 [11] _plot!(plt::Plots.Plot, plotattributes::Any, args::Any)
    @ Plots ~/.julia/packages/Plots/sxUvK/src/plot.jl:223
 [12] #plot#188
    @ ~/.julia/packages/Plots/sxUvK/src/plot.jl:102 [inlined]
 [13] plot(::Any, ::Any)
    @ Plots ~/.julia/packages/Plots/sxUvK/src/plot.jl:93
 [14] top-level scope
    @ REPL[15]:1
```

rsantet · Oct 03 '23 15:10

Thanks for reporting. I managed to fix the conversion error, but another error appears after that, which I still need to figure out.

rmsrosa · Oct 22 '23 02:10
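A possible workaround, sketched here and not taken from the thread: the stack trace shows the failure inside ChainPlots' `get_dimensions` while it processes the nested `Vector{Vector{Float32}}` input. A variant of `Concat` that accepts a single flat vector (or a features × batch matrix) and routes row `i` to the `i`-th branch avoids nested inputs altogether. `SplitConcat` is a hypothetical name, and the sketch assumes every branch takes exactly one input feature. It has not been tested with ChainPlots; given the further error mentioned above, `plot(ϕ, input)` may still fail with it.

```julia
using Flux

# Hypothetical variant of Concat: takes one flat input (vector or
# features × batch matrix) and feeds row i to the i-th branch.
# Assumption: every branch expects exactly one input feature.
struct SplitConcat{T}
    catted::T
end
SplitConcat(xs...) = SplitConcat(xs)

Flux.@functor SplitConcat

function (C::SplitConcat)(x::AbstractVecOrMat)
    # selectdim with a range keeps a 1-element vector (or a 1 × batch
    # matrix), which Dense accepts; the branch outputs are vcat-ed
    mapreduce(i -> C.catted[i](selectdim(x, 1, i:i)), vcat, eachindex(C.catted))
end

wdt = 16  # width of the hidden layers
ϕ = Chain(
    SplitConcat(
        Dense(1 => wdt, swish),
        Dense(1 => wdt, swish),
    ),
    Dense(2 * wdt => wdt, swish),
    Dense(wdt => 1),
)

input = ones(Float32, 2)  # one flat vector instead of a vector of vectors
y = ϕ(input)              # 1-element Vector{Float32}
```

Only the input format changes compared with the original `Concat`: the two branches and the merging `Dense` layers are the same, so the model computes the same kind of function on `[x1, x2]` as it did on `[[x1], [x2]]`.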