Dropout with rate: 1 raises on binary backend
@seanmor5 can you give me a short snippet to reproduce it or a failing test? :)
Sure, here's a failing inference:
model =
  Axon.input("data")
  |> Axon.dense(32)
  |> Axon.relu()
  |> Axon.dropout(rate: 1.0)

{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(Nx.iota({1, 10}), %{})
predict_fn.(params, Nx.iota({1, 10}))
We get:
** (ArithmeticError) bad argument in arithmetic expression
(complex 0.4.2) lib/complex.ex:579: Complex.divide/2
(nx 0.3.0) lib/nx/binary_backend.ex:2453: Nx.BinaryBackend."-binary_to_binary/4-lbc$^6/2-11-"/5
(nx 0.3.0) lib/nx/binary_backend.ex:673: Nx.BinaryBackend.element_wise_bin_op/4
(nx 0.3.0) lib/nx/defn/evaluator.ex:199: Nx.Defn.Evaluator.eval_apply/4
(nx 0.3.0) lib/nx/defn/evaluator.ex:81: Nx.Defn.Evaluator.eval/3
(elixir 1.13.4) lib/enum.ex:1715: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
(elixir 1.13.4) lib/enum.ex:1715: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
(nx 0.3.0) lib/nx/defn/evaluator.ex:173: Nx.Defn.Evaluator.eval_apply/4
We are doing a division by zero. We could return NaN, but perhaps it is actually a bug in Axon in this case?
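For reference, a minimal sketch of the same failure outside of Axon, assuming nx 0.3.0's default binary backend (where float division by zero goes through Complex.divide rather than returning Inf):

```elixir
# Assumed minimal repro (nx 0.3.0, binary backend): dividing an f32 tensor by a
# zero tensor raises instead of returning Inf/NaN.
Nx.divide(Nx.iota({1, 10}, type: :f32), Nx.tensor(0.0))
#=> ** (ArithmeticError) bad argument in arithmetic expression
```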
Hey @seanmor5, I've found where the division by zero is happening. It is here in the dropout definition, where keep_prob is 1 - rate (so rate: 1.0 makes it zero):
out = Nx.select(mask, input / keep_prob, Nx.tensor(0, type: Nx.type(input)))
I'm a bit confused about how this works. I understand that the mask uses keep_prob to decide whether or not to keep each value. However, on the second argument of that select, I'm not clear on why that division is needed. If the mask is 1 for that element, shouldn't the element be preserved as it is?
@Ian-GL In most DL frameworks we scale the kept activations in order to preserve the original input's properties (mean, variance, etc.).
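As a minimal sketch of that scaling in plain Nx (not Axon's actual implementation; Nx.random_uniform is the nx 0.3-era API, replaced by Nx.Random in later versions):

```elixir
# Inverted dropout sketch: each element survives with probability keep_prob and
# is scaled by 1 / keep_prob, so the output keeps the input's expected value.
input = Nx.iota({1, 8}, type: :f32)
rate = 0.5
keep_prob = 1.0 - rate

mask = Nx.less(Nx.random_uniform({1, 8}), keep_prob)
out = Nx.select(mask, Nx.divide(input, keep_prob), Nx.tensor(0.0, type: :f32))
```

With rate: 1.0, keep_prob is 0.0, and the Nx.divide above is exactly the division by zero from the report.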
In this case 1.0 is definitely an unrealistic value, so it's not surprising that it fails. I'm just not sure if we should add a small epsilon to the denominator or raise if the value is >= 1. I will need to check the behavior of other frameworks.
@Ian-GL If you want to tackle this, Keras raises when the dropout rate is not on the interval [0, 1), and I think that's a sensible default!
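A hypothetical sketch of what that validation could look like (module and function names are made up for illustration; this is not the actual Axon patch):

```elixir
# Hypothetical validation: raise when the dropout rate is outside [0, 1).
defmodule DropoutValidation do
  def validate_rate!(rate) when is_number(rate) and rate >= 0 and rate < 1, do: :ok

  def validate_rate!(rate) do
    raise ArgumentError,
          "expected dropout :rate to be a number in [0, 1), got: #{inspect(rate)}"
  end
end

DropoutValidation.validate_rate!(0.5)
#=> :ok
DropoutValidation.validate_rate!(1.0)
#=> ** (ArgumentError) expected dropout :rate to be a number in [0, 1), got: 1.0
```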
Sure, @seanmor5, thanks for the guidance! I'll make a PR.