Add more examples

Open seanmor5 opened this issue 4 years ago • 7 comments

I'm willing to accept examples on different datasets and models that demonstrate different parts of the Axon API and Axon's viability in the ecosystem. The TensorFlow guides are a great place to look for datasets and problems. If you're blocked on a specific issue, feel free to comment on the relevant issue with your use case :)

seanmor5 (Apr 08 '21 20:04)

Hi @seanmor5, I am trying to create an example that predicts diabetes with Axon and Nx, but I am still learning how it works.

My input tensors have these shapes:

X = #Nx.Tensor<f32[615][8]>

Y = #Nx.Tensor<s64[615]>

This is the code I'm trying to run: https://gist.github.com/tiagodavi/a905abeaf4d1f92c21f9df9043d196fe

Running it currently fails with this error:

StreamExecutor device (0): Host, Default Version
** (ArgumentError) expected input shapes to be equal, got {615} != {615, 1}
    (axon 0.1.0-dev) lib/axon/shared.ex:22: anonymous fn/1 in Axon.Shared."__defn:assert_shape!__"/2
    (nx 0.1.0) lib/nx/defn/compiler.ex:114: Nx.Defn.Compiler.__remote__/4
    (axon 0.1.0-dev) lib/axon/losses.ex:122: Axon.Losses."__defn:binary_cross_entropy__"/3
    (axon 0.1.0-dev) lib/axon/loop.ex:325: anonymous fn/5 in Axon.Loop.train_step/3
    (nx 0.1.0) lib/nx/defn/grad.ex:20: Nx.Defn.Grad.transform/3
    (axon 0.1.0-dev) lib/axon/loop.ex:332: anonymous fn/4 in Axon.Loop.train_step/3
    (axon 0.1.0-dev) lib/axon/loop.ex:1135: anonymous fn/4 in Axon.Loop.build_batch_fn/2
    (nx 0.1.0) lib/nx/defn/compiler.ex:101: Nx.Defn.Compiler.runtime_fun/4

tiagodavi (Jan 20 '22 20:01)

Hi @tiagodavi! Axon's implementation of binary cross-entropy (BCE) expects y_true and y_pred to have exactly the same shape (there's an explicit shape-equality check between them), so with a single sigmoid output, y_true needs a last dimension of size 1. If you add a new axis to y_true with Nx.new_axis(y, -1), the error should go away.
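A minimal sketch of that reshape, assuming a recent Nx release installed through Mix.install (the thread used Nx 0.1.0, so details may differ between versions). The four labels are made up and stand in for the 615-element y; the only point is the shape change from {n} to {n, 1}, which is the {batch, 1} shape a one-unit sigmoid output layer produces.

```elixir
# Assumes Elixir >= 1.12 (for Mix.install) and network access to Hex.
# The thread used Nx 0.1.0; this sketch targets a newer 0.x release.
Mix.install([{:nx, "~> 0.5"}])

# Made-up labels standing in for the thread's y: rank 1, shape {4}.
y = Nx.tensor([0, 1, 1, 0])

# Append a trailing axis of size 1, giving shape {4, 1}: the same
# {batch, 1} shape produced by a one-unit sigmoid output layer.
y_col = Nx.new_axis(y, -1)

IO.inspect(Nx.shape(y), label: "labels before")
IO.inspect(Nx.shape(y_col), label: "labels after")
```

The same call would apply to both y_train and y_test, so that the labels line up with the model's predictions in training and in evaluation.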

We can probably relax the strict shape-equality constraint; feel free to open a PR, otherwise I will open an issue to track it.

Also, you might be interested in trying out [Explorer](https://github.com/elixir-nx/explorer) for easier Nx/Axon interop with structured data :)

seanmor5 (Jan 20 '22 20:01)

Thank you @seanmor5.

I was able to fix the error, but the accuracy is quite bad, so I am probably still doing something wrong.

model =
  input
  |> Axon.dense(features, activation: :relu)
  |> Axon.dense(features, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

trained_model =
  model
  |> Axon.Loop.trainer(:binary_cross_entropy, :adam)
  |> Axon.Loop.run([{x_train, y_train}], epochs: 10, compiler: EXLA)

# trying to turn the sigmoid outputs into 0/1 class labels here
result =
  model
  |> Axon.predict(trained_model, x_test, compiler: EXLA)
  |> Nx.map([type: {:s, 64}], fn x ->
    if x > 0.5, do: 1, else: 0
  end)

IO.inspect(Axon.Metrics.accuracy(y_test, result))

#Nx.Tensor<
  f32
  0.3464052379131317
>

tiagodavi (Jan 20 '22 21:01)

Axon.Metrics.accuracy should do that thresholding for you. What do you get if you feed the result of Axon.predict(model, trained_model, x_test, compiler: EXLA) directly into Axon.Metrics.accuracy?
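To make the thresholding concrete, here is a hedged sketch of thresholded binary accuracy written directly in Nx. ThresholdedAccuracy.compute/3 is a hypothetical helper, not part of Axon, and not necessarily how Axon.Metrics.accuracy is implemented. It assumes sigmoid probabilities and 0/1 labels of the same shape, counts a prediction as class 1 when it is strictly greater than 0.5, and assumes a recent Nx from Hex.

```elixir
# Assumes Elixir >= 1.12 (for Mix.install) and a recent 0.x Nx release.
Mix.install([{:nx, "~> 0.5"}])

defmodule ThresholdedAccuracy do
  @moduledoc """
  Hypothetical helper, not Axon's implementation: binary accuracy
  with a fixed threshold applied to sigmoid outputs.
  """

  def compute(y_true, y_pred, threshold \\ 0.5) do
    # Refuse mismatched shapes instead of letting Nx broadcast
    # {n} against {n, 1} into an {n, n} comparison.
    if Nx.shape(y_true) != Nx.shape(y_pred) do
      raise ArgumentError,
            "shape mismatch: #{inspect(Nx.shape(y_true))} vs #{inspect(Nx.shape(y_pred))}"
    end

    y_pred
    # 1 where the probability is above the threshold, else 0
    |> Nx.greater(threshold)
    # 1 where the predicted class matches the label, else 0
    |> Nx.equal(y_true)
    # fraction of correct predictions, as a float scalar
    |> Nx.mean()
  end
end

# Made-up sigmoid outputs and labels, both shaped {4, 1}.
y_pred = Nx.tensor([[0.9], [0.2], [0.7], [0.4]])
y_true = Nx.tensor([[1], [0], [0], [0]])

acc = ThresholdedAccuracy.compute(y_true, y_pred)
IO.inspect(Nx.to_number(acc), label: "accuracy")
```

The explicit shape check is a design choice for the sketch: under Nx's broadcasting rules, comparing labels of shape {n} with predictions of shape {n, 1} would silently produce an {n, n} result, and its mean would no longer be a per-example accuracy.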

seanmor5 (Jan 20 '22 22:01)

Something like this?

result =
  model
  |> Axon.predict(trained_model, x_test, compiler: EXLA)
  |> Axon.Metrics.accuracy(y_test)

IO.inspect(result)

#Nx.Tensor<
  f32
  0.013071895577013493
>

In the meantime, I'll take this course to see if I can learn Nx better: https://grox.io/language/nx/course

tiagodavi (Jan 20 '22 22:01)

It's probably a bug; please send me the gist!

seanmor5 (Jan 20 '22 23:01)

Sure, this is the most up-to-date version: https://gist.github.com/tiagodavi/a905abeaf4d1f92c21f9df9043d196fe

tiagodavi (Jan 21 '22 12:01)