Flux.jl

RFC: do not include type params of `Chain` children in compact `show`

Open · ToucheSir opened this issue 3 years ago · 2 comments

When looking at stacktraces with Flux models, it's not uncommon to see types that take up a screen's worth of space or more. This is especially relevant for large, nested models such as those found in Metalhead. One approach that has been tried to reduce this printing noise is https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/show.jl, which deliberately leaves out certain type params. I think it would be worthwhile to explore a similar approach for Flux container layers. We could turn something like this:

Flux.Chain{Tuple{
  Flux.Chain{Tuple{
    Flux.Chain{Tuple{
      Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}},
      typeof(Metalhead.Layers._flatten_spatial),
      typeof(identity)
    }},
    Metalhead.Layers.ClassTokens{Array{Float32, 3}},
    Metalhead.Layers.ViPosEmbedding{Matrix{Float32}},
    Flux.Dropout{Float64, Colon, Random.TaskLocalRNG}, 
    Flux.Chain{Vector{Flux.Chain{Tuple{
      Flux.SkipConnection{Flux.Chain{Tuple{
        Flux.LayerNorm{typeof(identity), Flux.Scale{typeof(identity), Vector{Float32}, Vector{Float32}}, Float32, 1}, 
        Metalhead.Layers.MHAttention{
          Flux.Dense{typeof(identity), Matrix{Float32}, Bool}, Flux.Dropout{Float64, Colon, Random.TaskLocalRNG},
          Flux.Chain{Tuple{
            Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}},
            Flux.Dropout{Float64, Colon, Random.TaskLocalRNG}
          }}
        }
      }}, typeof(+)},
      Flux.SkipConnection{Flux.Chain{Tuple{
        Flux.LayerNorm{typeof(identity), Flux.Scale{typeof(identity), Vector{Float32}, Vector{Float32}}, Float32, 1},
        Flux.Chain{Tuple{
          Flux.Dense{typeof(NNlib.gelu), Matrix{Float32}, Vector{Float32}},
          Flux.Dropout{Float64, Colon, Random.TaskLocalRNG},
          Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}},
          Flux.Dropout{Float64, Colon, Random.TaskLocalRNG}
        }}
      }}, typeof(+)}
    }}}},
    Metalhead.var"#120#121"
  }},
  Flux.Chain{Tuple{
    Flux.LayerNorm{typeof(identity), Flux.Scale{typeof(identity), Vector{Float32}, Vector{Float32}}, Float32, 1},
    Flux.Dense{typeof(NNlib.tanh_fast), Matrix{Float32}, Vector{Float32}}
  }}
}}

(Note: manually formatted!)

To one of these (other combinations of braces and ellipses welcome):

Flux.Chain(Flux.Chain, Flux.Chain)
Flux.Chain(Flux.Chain(..), Flux.Chain(..))
Flux.Chain(Flux.Chain{..}, Flux.Chain{..})
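A minimal sketch of the third form, assuming a toy `ToyChain` struct as a hypothetical stand-in for `Flux.Chain` (this is not Flux's code). It overloads `Base.show` for the container's *type* so that only the names of its children are printed, with their own parameters elided as `{..}`. Only tuple-backed containers are expanded; anything else (e.g. a `Vector` of layers) collapses to `ToyChain{..}`.

```julia
# Hypothetical stand-in for a container layer such as Flux.Chain.
struct ToyChain{T}
    layers::T
end

# Name of a child type with its parameters elided, e.g. `Dense{..}`.
# Parameter-free types (e.g. `typeof(identity)`) print as usual.
compact_name(T::DataType) = isempty(T.parameters) ? string(T) : string(nameof(T), "{..}")
compact_name(x) = string(x)  # UnionAlls, Unions, Vararg, etc.: fall back to normal printing

# Compact display of the container *type*, as it would appear in a stacktrace.
function Base.show(io::IO, ::Type{T}) where {T<:ToyChain}
    L = T isa DataType ? T.parameters[1] : nothing
    if L isa DataType && L <: Tuple
        # Tuple-backed container: list the children by name only.
        print(io, nameof(ToyChain), "(", join((compact_name(p) for p in L.parameters), ", "), ")")
    elseif T isa DataType
        # Some other layer storage (e.g. a Vector): hide it entirely.
        print(io, nameof(ToyChain), "{..}")
    else
        # The bare UnionAll `ToyChain` itself.
        print(io, nameof(ToyChain))
    end
end

println(repr(ToyChain{Tuple{ToyChain{Tuple{Int}}, typeof(identity)}}))
```

This hides the parameters everywhere the type is printed, not only in stacktraces, which is the trade-off being weighed here.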

Thoughts?

ToucheSir · Jun 15 '22 02:06

Could this somehow be extended to gradient errors as well? Zygote spits out enormous stack traces, and by the time I find the error, it's buried under an absolute mountain of repeated mentions of the same lines of code (which, in the case of Metalhead, is often a screen's worth by itself, given nested Chains and the like). Zygote errors are also pretty cryptic as is, so every bit helps.

theabhirath · Jun 15 '22 13:06

That's part of what the linked file does, actually. Trimming type parameters from gradients is a little trickier, since Zygote likes to use plain ol' Julia types, but we could look into better displays for self-hosted ones like Jnew. The repeated references to the same line are, unfortunately, distinct functions, and it's hard to make them more specific because they come from generated functions that don't map to any existing source code (they are created mostly out of thin air during the AD transform).

ToucheSir · Jun 16 '22 00:06