
Fix UNet implementation with arbitrary channel sizes (#243)

vinayakjeet opened this issue 1 year ago • 2 comments


Bug Description: The current UNet implementation in the Metalhead package only works with input tensors that have exactly 3 channels. This restriction causes compatibility issues when users try to use UNet with inputs of other channel sizes, such as single-channel grayscale images.
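For illustration, a minimal reproduction of the limitation (the constructor signature matches the one used later in this thread):

using Metalhead

# Works: the default encoder backbone expects 3 input channels
UNet((128, 128), 3, 3, Metalhead.backbone(DenseNet(121)))

# Fails before this patch: a single-channel input does not match the
# backbone's first Conv layer, which assumes 3 channels
UNet((128, 128), 1, 3, Metalhead.backbone(DenseNet(121)))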

Patch Description: To address this limitation, I've modified the UNet implementation to support input tensors with arbitrary channel sizes, so the model can now handle inputs with varying channel dimensions.

Test Case:

using Metalhead
UNet((128, 128), 1, 3, Metalhead.backbone(DenseNet(121)))

With the patch, this UNet model can be constructed without any errors.
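A fuller version of that check might run a forward pass as well (a sketch; Flux tensors use the width × height × channels × batch layout):

using Metalhead

model = UNet((128, 128), 1, 3, Metalhead.backbone(DenseNet(121)))

x = rand(Float32, 128, 128, 1, 1)  # single-channel input, batch of 1
y = model(x)
size(y, 3) == 3  # the output should have `outplanes` channels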

vinayakjeet avatar Mar 22 '24 14:03 vinayakjeet

Hi Vinayakjeet, thanks for the PR! Unfortunately, I don't think this does what we want yet. The problem is that inchannels isn't being passed to the model backbone. What you've done is change the input being passed to the Flux.outputsize function, which actually causes an error when I try to initialise the model:

julia> using Metalhead

julia> model = UNet((128,128),1,3,Metalhead.backbone(DenseNet(121)))
ERROR: DimensionMismatch: layer Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false) expects size(input, 3) == 3, but got 128×128×1×1 Array{Flux.NilNumber.Nil, 4}
Stacktrace:
  [1] _size_check(layer::Flux.Conv{2, 2, typeof(identity), Array{…}, Bool}, x::Array{Flux.NilNumber.Nil, 4}, ::Pair{Int64, Int64})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:195
  [2] (::Flux.Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})(x::Array{Flux.NilNumber.Nil, 4})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/conv.jl:198
  [3] #outputsize#340
    @ ~/.julia/packages/Flux/jgpVj/src/outputsize.jl:93 [inlined]
  [4] outputsize(m::Flux.Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool}, inputsizes::NTuple{4, Int64})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/outputsize.jl:91
  [5] unetlayers(layers::Vector{…}, sz::NTuple{…}; outplanes::Nothing, skip_upscale::Int64, m_middle::typeof(Metalhead.unet_middle_block))
    @ Metalhead ~/Code/Metalhead.jl/src/convnets/unet.jl:34
  [6] unet(encoder_backbone::Flux.Chain{…}, imgdims::Tuple{…}, inchannels::Int64, outplanes::Int64, final::typeof(Metalhead.unet_final_block), fdownscale::Int64)
    @ Metalhead ~/Code/Metalhead.jl/src/convnets/unet.jl:81
  [7] unet
    @ ~/Code/Metalhead.jl/src/convnets/unet.jl:76 [inlined]
  [8] #UNet#175
    @ ~/Code/Metalhead.jl/src/convnets/unet.jl:120 [inlined]
  [9] UNet(imsize::Tuple{Int64, Int64}, inchannels::Int64, outplanes::Int64, encoder_backbone::Flux.Chain{Tuple{…}})
    @ Metalhead ~/Code/Metalhead.jl/src/convnets/unet.jl:118
 [10] top-level scope
    @ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.

I would suggest that you try and rewrite the function in such a way that inchannels is passed along to the encoder backbone.
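For what it's worth, a sketch of the direction (not the fix itself): the encoder can be constructed with the desired channel count up front, since Metalhead's model constructors accept an inchannels keyword, and the UNet plumbing then just needs to see the same count.

using Metalhead

# Build the encoder so its first Conv layer accepts 1 input channel
encoder = Metalhead.backbone(DenseNet(121; inchannels = 1))
model = UNet((128, 128), 1, 3, encoder)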

theabhirath avatar Mar 22 '24 16:03 theabhirath

As a beginner contributor to the codebase, could you review the logic I have implemented? Additionally, I have encountered a MethodError indicating a mismatch in method signatures for the unet function. It appears that there might be an issue with how the encoder_backbone is instantiated or utilized within the unet function. Could you please review the instantiation and usage of the encoder_backbone?
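For reference, the signatures Julia tried to match can be listed in the REPL, which usually makes the mismatch visible:

julia> methods(Metalhead.unet)  # lists every defined signature for unet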

vinayakjeet avatar Mar 24 '24 04:03 vinayakjeet