ONNX.jl

No method matching BatchNorm when loading model

Open · tclements opened this issue 4 years ago · 0 comments

I'm trying to load the ArcFace model from ONNX.

Package versions:

```
(v1.3) pkg> st
    Status `~/.julia/environments/v1.3/Project.toml`
  [c5f51814] CUDAdrv v6.0.0
  [be33ccc6] CUDAnative v2.10.2
  [3a865a2d] CuArrays v1.7.2
  [587475ba] Flux v0.10.1
  [d0dd6a25] ONNX v0.1.1
```

The beginning of `model.jl` looks like this:

```julia
using Statistics
Mul(a,b,c) = b .* reshape(c, (1,1,size(c)[a],1))
Add(axis, A ,B) = A .+ reshape(B, (1,1,size(B)[1],1))
begin
    c_1 = BatchNorm(identity, weights["fc1_beta"], weights["fc1_gamma"], broadcast(Float32, weights["fc1_moving_mean"]), broadcast(Float32, broadcast(sqrt, broadcast(+, 2.0f-5, weights["fc1_moving_var"]))), 2.0f-5, 0.9f0, false)
    c_2 = BatchNorm(identity, weights["bn1_beta"], weights["bn1_gamma"], broadcast(Float32, weights["bn1_moving_mean"]), broadcast(Float32, broadcast(sqrt, broadcast(+, 2.0f-5, weights["bn1_moving_var"]))), 2.0f-5, 0.9f0, false)
    c_3 = BatchNorm(identity, weights["stage4_unit3_bn3_beta"], weights["stage4_unit3_bn3_gamma"], broadcast(Float32, weights["stage4_unit3_bn3_moving_mean"]), broadcast(Float32, broadcast(sqrt, broadcast(+, 2.0f-5, weights["stage4_unit3_bn3_moving_var"]))), 2.0f-5, 0.9f0, false)
    c_4 = CrossCor(weights["stage4_unit3_conv2_weight"], Float32[0.0], relu, var"stride=(1, 1)", var"pad=(1, 1, 1, 1)", var"dilation=(1, 1)")
```

When I try to load the model:

```julia
using Flux, ONNX
weights = ONNX.load_weights("weights.bson")
model = include("model.jl")
```

I get an error from `BatchNorm`:

```
model = include("model.jl")
ERROR: LoadError: MethodError: no method matching BatchNorm(::typeof(identity), ::Base.ReinterpretArray{Float32,1,Float32,Array{Float32,1}}, ::Base.ReinterpretArray{Float32,1,Float32,Array{Float32,1}}, ::Array{Float32,1}, ::Array{Float32,1}, ::Float32, ::Float32, ::Bool)
Closest candidates are:
  BatchNorm(::F, ::V, ::V, ::W, ::W, ::N, ::N) where {F, V, W, N} at /home/timclements/.julia/packages/Flux/2i5P1/src/layers/normalise.jl:123
  BatchNorm(::Integer, ::Any; initβ, initγ, ϵ, momentum) at /home/timclements/.julia/packages/Flux/2i5P1/src/layers/normalise.jl:133
```
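
The "Closest candidates" line points at the likely cause: the positional constructor in Flux v0.10.1 takes seven arguments, `(::F, ::V, ::V, ::W, ::W, ::N, ::N)`, while the generated call passes eight, ending in `false`. The argument types themselves look compatible, since β and γ share the parameter `V` and μ and σ² share `W`, so reinterpreted arrays next to plain arrays should match. The minimal sketch below reproduces the arity mismatch with a hypothetical `ToyBatchNorm` struct that only mirrors that seven-slot signature; it is an illustration under that assumption, not Flux's actual type.

```julia
# Hypothetical stand-in mirroring the 7-slot signature from the error message:
# BatchNorm(::F, ::V, ::V, ::W, ::W, ::N, ::N). Julia's default constructor for
# this struct therefore accepts exactly seven positional arguments.
mutable struct ToyBatchNorm{F,V,W,N}
    λ::F          # activation (identity in model.jl)
    β::V          # shift
    γ::V          # scale (same type parameter as β)
    μ::W          # running mean
    σ²::W         # running variance slot (same type parameter as μ)
    ϵ::N
    momentum::N
end

# ReinterpretArrays for β/γ and plain Arrays for μ/σ², as in the error message.
β = reinterpret(Float32, zeros(UInt32, 4))
γ = reinterpret(Float32, zeros(UInt32, 4))
μ = zeros(Float32, 4)
σ = ones(Float32, 4)

# Seven arguments: matches the default constructor, types are fine.
seven = ToyBatchNorm(identity, β, γ, μ, σ, 2.0f-5, 0.9f0)

# Eight arguments with a trailing Bool, as in the generated model.jl: no method.
eight_fails = try
    ToyBatchNorm(identity, β, γ, μ, σ, 2.0f-5, 0.9f0, false)
    false
catch err
    err isa MethodError
end

println(eight_fails)   # true
```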

This looks very similar to #17, but I am using ONNX v0.1.1, which should already include that fix. The error is thrown from Flux, so the ONNX version of `BatchNorm` that accepts `identity` is not being called.
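
Why an ONNX-side method could be bypassed follows from Julia's name resolution. `include("model.jl")` evaluates the file in `Main`, so the bare name `BatchNorm` means whatever binding `Main` has; because the candidate list points into Flux's `normalise.jl`, that binding is Flux's. A method defined elsewhere only takes part in this call if it is added to that same function. The sketch below uses two hypothetical modules, `FakeFlux` and `FakeONNX`, in place of the real packages; it does not reflect how ONNX.jl actually defines its method, and the method that drops the trailing `Bool` is an assumed workaround, not a confirmed fix.

```julia
# Hypothetical stand-in for Flux: owns the layer type and exports its name.
module FakeFlux
export BatchNorm
mutable struct BatchNorm{F,V,W,N}
    λ::F
    β::V
    γ::V
    μ::W
    σ²::W
    ϵ::N
    momentum::N
end
end # module FakeFlux

# Hypothetical stand-in for ONNX: this creates a *separate* function that only
# shares the name. It is not exported and adds no method to FakeFlux.BatchNorm.
module FakeONNX
BatchNorm(λ, β, γ, μ, σ, ϵ, momentum, ::Bool) = :never_reached
end # module FakeONNX

using .FakeFlux, .FakeONNX   # bare `BatchNorm` in Main now means FakeFlux.BatchNorm

v = zeros(Float32, 4)
call_like_model_jl() = BatchNorm(identity, v, v, v, v, 2.0f-5, 0.9f0, false)

before = try
    call_like_model_jl()
    :ok
catch err
    err isa MethodError ? :method_error : rethrow()
end

# Assumed workaround: extend the function the name actually resolves to
# (a qualified method definition) and discard the trailing Bool.
FakeFlux.BatchNorm(λ, β, γ, μ, σ, ϵ, momentum, ::Bool) =
    FakeFlux.BatchNorm(λ, β, γ, μ, σ, ϵ, momentum)

after = call_like_model_jl()
println(before, " ", after isa FakeFlux.BatchNorm)   # method_error true
```

Defining `FakeONNX.BatchNorm` without qualification leaves `FakeFlux.BatchNorm` untouched, which is the same symptom as the MethodError above; only the qualified definition changes what `model.jl` dispatches to.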

tclements, Feb 28 '20 03:02