How to train a network with multiple prediction heads?

Open aaalgo opened this issue 4 years ago • 0 comments

Description

My network has m inputs (X1, X2, ... Xm) and n prediction heads (Y1, Y2, ... Yn). Each input and each output may have a different shape. Each prediction head has its own loss function, Li for Yi. The total loss is the sum of the n loss functions.

My input data shape is:

  • data = [X1, X2, ... Xm] <-- cannot be concatenated due to incompatible shapes
  • Y = [Y1, Y2, ... Yn] <-- cannot be concatenated due to incompatible shapes.

After data goes through the network I have:

  • Y' = [Y1', Y2', ... Yn'] <-- predictions
  • Y = [Y1, Y2, ... Yn]

For the loss calculation, I would like to use a "zip"-like layer that converts the above to

  • [(Y1', Y1), (Y2', Y2), ... (Yn', Yn)]

And then apply tl.Serial(tl.Parallel(L1, L2, ... Ln), tl.Add()).

I searched through the documentation and did not find such a "zip" layer.
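To make the request concrete, here is a minimal plain-Python sketch of the reordering such a "zip" layer would perform. It assumes the predictions and targets sit in one flat sequence, predictions first; `zip_indices`, `total_loss`, `squared_error`, and `absolute_error` are hypothetical names used only for illustration. In Trax, the same index list could probably be handed to `tl.Select`, which reorders stack items by index, but that is untested here. Also, as far as I can tell `tl.Add()` sums only two inputs, so more than two heads would need a further reduction step.

```python
def zip_indices(n):
    """Indices that reorder [Y1'..Yn', Y1..Yn] into [Y1', Y1, ..., Yn', Yn]."""
    return [i + offset for i in range(n) for offset in (0, n)]


def total_loss(predictions, targets, loss_fns):
    """Pair each prediction with its target, apply its loss, and sum the results."""
    n = len(loss_fns)
    if not (len(predictions) == len(targets) == n):
        raise ValueError("need one prediction and one target per loss function")
    # Flat "stack": all predictions first, then all targets.
    stack = list(predictions) + list(targets)
    # The "zip" step: interleave predictions and targets.
    zipped = [stack[i] for i in zip_indices(n)]
    # Each loss consumes two adjacent items (prediction, target).
    pairs = [(zipped[2 * k], zipped[2 * k + 1]) for k in range(n)]
    return sum(loss(pred, tgt) for loss, (pred, tgt) in zip(loss_fns, pairs))


# Hypothetical per-head losses operating on plain lists of floats.
def squared_error(pred, tgt):
    return sum((p - t) ** 2 for p, t in zip(pred, tgt))


def absolute_error(pred, tgt):
    return sum(abs(p - t) for p, t in zip(pred, tgt))


# Two heads with different output shapes (lengths 2 and 3).
example_predictions = [[1.0, 2.0], [0.0, 1.0, 2.0]]
example_targets = [[1.0, 4.0], [1.0, 1.0, 0.0]]
example_total = total_loss(
    example_predictions, example_targets, [squared_error, absolute_error]
)
print(zip_indices(3))  # [0, 3, 1, 4, 2, 5]
print(example_total)   # 7.0
```

Because the pairing is a fixed permutation, the heads never need to share a shape; only the order of the items changes.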

I would appreciate your advice.

Thank you very much!

Environment information

OS: <your answer here>

$ pip freeze | grep trax
# your output here

$ pip freeze | grep tensor
# your output here

$ pip freeze | grep jax
# your output here

$ python -V
# your output here

For bugs: reproduction and error logs

# Steps to reproduce:
...
# Error logs:
...

aaalgo, Jan 15 '21 01:01