Lux.jl
Lux.jl copied to clipboard
Broadcast Layer
A simple implementation:
"""
    BroadcastLayer{T<:NamedTuple}

Container layer that stores its sub-layers as a `NamedTuple` in the `layers` field.
The `(:layers,)` type parameter of `AbstractExplicitContainerLayer` names that field
as the one holding the wrapped layers.
"""
struct BroadcastLayer{T<:NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)}
    layers::T
end
"""
    BroadcastLayer(layers...)

Build a `BroadcastLayer` from positional layers, naming them `layer_1`, `layer_2`, ...
in the order given.

# Throws
- `ArgumentError` if any layer reports a non-zero `statelength`.
"""
function BroadcastLayer(layers...)
    # Reject the first layer that carries state.
    foreach(layers) do layer
        iszero(statelength(layer)) && return nothing
        throw(ArgumentError("Stateful layer `$layer` are not supported for `BroadcastLayer`."))
    end
    layer_names = ntuple(length(layers)) do idx
        return Symbol("layer_$idx")
    end
    named_layers = NamedTuple{layer_names}(layers)
    return BroadcastLayer(named_layers)
end
"""
    BroadcastLayer(; kwargs...)

Build a `BroadcastLayer` from keyword arguments: each keyword becomes the name of a
layer and its value the layer itself.

# Throws
- `ArgumentError` if any layer reports a non-zero `statelength`.
"""
function BroadcastLayer(; kwargs...)
    # The previous definition forwarded an undefined `connection` variable, so every
    # call (including the zero-argument `BroadcastLayer()`) raised `UndefVarError`.
    layers = (; kwargs...)
    # Same restriction as the positional constructor: the forward pass returns the
    # input state unchanged, so layers that carry state are rejected up front.
    for l in values(layers)
        if !iszero(statelength(l))
            throw(ArgumentError("Stateful layer `$l` are not supported for `BroadcastLayer`."))
        end
    end
    # A single NamedTuple argument dispatches to the struct's own constructor.
    return BroadcastLayer(layers)
end
# Forward pass: broadcasts `first ∘ Lux.apply` over the stored layers, the input `x`,
# and the values of `ps` and `st`, then returns the collected results together with
# the state `st` exactly as it was passed in.
# NOTE(review): relies on `ps` and `st` listing their entries in the same order as
# `m.layers` — confirm against how parameters/states are initialised.
function (m::BroadcastLayer)(x, ps, st::NamedTuple{names}) where {names}
    # Keep only the first element of whatever `Lux.apply` returns for each layer.
    apply_first = first ∘ Lux.apply
    outputs = broadcast(apply_first, values(m.layers), x, values(ps), values(st))
    return outputs, st
end
# `keys` of a `BroadcastLayer` are the names of its stored layers. `getfield` reads
# the `layers` field directly rather than going through `getproperty`.
function Base.keys(m::BroadcastLayer)
    named_layers = getfield(m, :layers)
    return Base.keys(named_layers)
end
Originally posted by @avik-pal in https://github.com/LuxDL/Lux.jl/issues/282#issuecomment-1513524492