How to train a network with multiple prediction heads?
Description
My network has m inputs (X1, X2, ... Xm) and n prediction heads (Y1, Y2, ... Yn). The inputs and outputs may all have different shapes. Each prediction head Yi has its own loss function Li, and the total loss is the sum of the n losses.
My input data is:
- data = [X1, X2, ... Xm] <-- cannot be concatenated due to incompatible shapes
- Y = [Y1, Y2, ... Yn] <-- cannot be concatenated due to incompatible shapes.
After data goes through the network I have:
- Y' = [Y1', Y2', ... Yn'] <-- predictions
- Y = [Y1, Y2, ... Yn]
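For concreteness, here is a minimal sketch of the kind of model I have in mind, for m = 2 inputs and n = 2 heads; the Dense sizes are placeholders, and I am assuming the data pipeline yields batches as tuples (X1, X2, Y1, Y2):

```python
import trax.layers as tl

# Trax layers pass values on a data stack, so the differently shaped
# inputs never need to be concatenated with each other.
model = tl.Serial(
    tl.Parallel(                  # one branch per input
        tl.Dense(64),             # encodes X1
        tl.Dense(64),             # encodes X2
    ),
    tl.Concatenate(n_items=2),    # merge the two encodings
    tl.Branch(                    # one head per output, reading the merged encoding
        tl.Dense(10),             # produces Y1'
        tl.Dense(3),              # produces Y2'
    ),
)
```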
For the loss calculation, I would like a "zip"-like layer that converts the above into
- [(Y1', Y1), (Y2', Y2), ... (Yn', Yn)]
and then applies tl.Serial(tl.Parallel(L1, L2, ... Ln), tl.Add()).
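A sketch of the loss layer I would like to build, again for n = 2. It assumes the stack after the model runs on a labeled batch is (Y1', Y2', Y1, Y2); tl.Select is my best guess at doing the "zip" reordering, and the per-head losses L1 and L2 are hypothetical mean-squared-error stand-ins written with tl.Fn:

```python
import trax.layers as tl

# Placeholder per-head losses; my real losses would replace these.
L1 = tl.Fn('L1', lambda y_hat, y: ((y_hat - y) ** 2).mean())
L2 = tl.Fn('L2', lambda y_hat, y: ((y_hat - y) ** 2).mean())

# Select reorders (Y1', Y2', Y1, Y2) into the "zipped" order
# (Y1', Y1, Y2', Y2), Parallel applies one loss per (prediction, target)
# pair, and Add sums the two scalars into the total loss.
total_loss = tl.Serial(
    tl.Select([0, 2, 1, 3]),   # (Y1', Y2', Y1, Y2) -> (Y1', Y1, Y2', Y2)
    tl.Parallel(L1, L2),       # -> (loss1, loss2)
    tl.Add(),                  # -> loss1 + loss2
)
```

Is this stack-reordering approach with tl.Select the intended way to do it, or is there a dedicated combinator for this?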
I searched through the documentation and could not find such a "zip" layer.
I would appreciate your advice;
thank you very much!
Environment information
OS: <your answer here>
$ pip freeze | grep trax
# your output here
$ pip freeze | grep tensor
# your output here
$ pip freeze | grep jax
# your output here
$ python -V
# your output here
For bugs: reproduction and error logs
# Steps to reproduce:
...
# Error logs:
...