Lux.jl icon indicating copy to clipboard operation
Lux.jl copied to clipboard

Broadcast Layer

Status: Open. avik-pal opened this issue 1 year ago • 0 comments

A simple implementation:

struct BroadcastLayer{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)}
    layers::T
end

function BroadcastLayer(layers...)
    for l in layers
        if !iszero(statelength(l))
            throw(ArgumentError("Stateful layer `$l` are not supported for `BroadcastLayer`."))
        end
    end
    names = ntuple(i -> Symbol("layer_$i"), length(layers))
    return BroadcastLayer(NamedTuple{names}(layers))
end

BroadcastLayer(; kwargs...) = BroadcastLayer(connection, (; kwargs...))

function (m::BroadcastLayer)(x, ps, st::NamedTuple{names}) where {names}
    results = (first ∘ Lux.apply).(values(m.layers), x, values(ps), values(st))
    return results, st
end

Base.keys(m::BroadcastLayer) = Base.keys(getfield(m, :layers))

Originally posted by @avik-pal in https://github.com/LuxDL/Lux.jl/issues/282#issuecomment-1513524492

avik-pal avatar Oct 20 '23 16:10 avik-pal