
Weight updates for branch heads

Open Amakri1020 opened this issue 4 years ago • 3 comments

Hi, I noticed that for every training sample the network outputs predictions for all 5 output branches, but the loss is then (correctly) calculated using only the output of the branch that corresponds to that sample's high-level command, and those losses are summed over all samples in the batch to get the total_loss tensor. Is this total loss value then used to update all 5 branches? Or is an individual loss calculated somewhere for each branch, using only the samples that branch is supposed to predict on given the high-level command?

Hopefully the question is clear; I can try to rephrase it if it isn't!

Thanks a lot for this repo, it has been very useful!

Amakri1020 avatar Jul 18 '19 05:07 Amakri1020

Hi, I noticed that for every training sample the network outputs predictions for all 5 output branches, but the loss is then (correctly) calculated using only the output of the branch that corresponds to that sample's high-level command, and those losses are summed over all samples in the batch to get the total_loss tensor.

You are right. The loss is accumulated by adding up the losses of the separate branches, exactly as you describe.

Is this total loss value then used to update all 5 branches?

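Through the gradients (the derivative of the loss with respect to the weights), the update only influences the branches whose commands were present in the training batch. A branch that no sample in the batch selects contributes nothing to the total loss, so the gradient for its head weights is zero.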
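To make the zero-gradient argument concrete, here is a minimal pure-Python sketch. It is a toy, not the repo's TensorFlow model: each of the 5 "heads" is a single hypothetical scalar weight, the command indices are arbitrary, and the names `total_loss` and `numerical_gradient` are made up for illustration. It computes the masked, summed loss and then estimates the gradient for each head by finite differences.

```python
# Hypothetical toy setup: 5 scalar "heads", one weight each (not the repo's model).

def total_loss(head_weights, batch):
    """Sum of squared errors, counting only the head selected by each sample's command."""
    loss = 0.0
    for x, target, command in batch:
        for b, w in enumerate(head_weights):
            mask = 1.0 if b == command else 0.0  # one-hot mask over branches
            loss += mask * (w * x - target) ** 2
    return loss

def numerical_gradient(head_weights, batch, eps=1e-6):
    """Central finite-difference gradient of total_loss w.r.t. each head weight."""
    grads = []
    for b in range(len(head_weights)):
        up = list(head_weights)
        up[b] += eps
        down = list(head_weights)
        down[b] -= eps
        grads.append((total_loss(up, batch) - total_loss(down, batch)) / (2 * eps))
    return grads

# Each sample is (input, target, command index); only commands 1 and 3 occur.
batch = [(1.0, 0.5, 1), (2.0, -1.0, 1), (0.5, 0.2, 3), (1.5, 1.0, 3)]
head_weights = [0.3, -0.2, 0.8, 0.1, 0.6]

grads = numerical_gradient(head_weights, batch)
for b, g in enumerate(grads):
    print(f"head {b}: gradient {g:+.4f}")
```

In this sketch only heads 1 and 3 get a nonzero gradient, even though a single scalar loss is computed for the whole batch.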

markus-hinsche avatar Jul 18 '19 07:07 markus-hinsche

So if for example we have a batch of 20 images, 10 of them are Right and 10 of them are Straight, does this mean the Left and Follow branch heads are not updated at all for this batch? This would make sense to me but it doesn't seem reflected in the code, since only 1 total loss value is calculated and the entire network is trained based on this value.

Amakri1020 avatar Jul 18 '19 21:07 Amakri1020

So if for example we have a batch of 20 images, 10 of them are Right and 10 of them are Straight, does this mean the Left and Follow branch heads are not updated at all for this batch?

I think this is correct. As far as I can tell, the code at https://github.com/merantix/imitation-learning/blob/master/imitation/models/conditional_il_model.py#L53-L81 does exactly that.

You could run an experiment: train on a single batch that contains data for just one condition, and see what happens to the other conditions' heads.
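A rough sketch of that experiment, again in pure Python rather than with the repo's TensorFlow code. It assumes a hypothetical model with one shared scalar "trunk" weight and one scalar weight per head, hand-derived gradients, and one plain SGD step. The names `forward` and `sgd_step` are made up for illustration.

```python
# Hypothetical toy model: one shared trunk weight plus one scalar weight per head.

def forward(trunk, heads, x):
    """Return the prediction of every head for input x."""
    feature = trunk * x
    return [w * feature for w in heads]

def sgd_step(trunk, heads, batch, lr=0.1):
    """One plain SGD step on the command-masked squared-error loss."""
    grad_trunk = 0.0
    grad_heads = [0.0] * len(heads)
    for x, target, command in batch:
        error = heads[command] * trunk * x - target
        # Only the selected head appears in the loss, so only it gets a gradient.
        grad_heads[command] += 2 * error * trunk * x
        # The shared trunk receives gradient from every sample.
        grad_trunk += 2 * error * heads[command] * x
    new_trunk = trunk - lr * grad_trunk
    new_heads = [w - lr * g for w, g in zip(heads, grad_heads)]
    return new_trunk, new_heads

trunk, heads = 1.0, [0.3, -0.2, 0.8, 0.1, 0.6]
batch = [(1.0, 0.5, 2), (2.0, 1.0, 2)]  # every sample uses command index 2

new_trunk, new_heads = sgd_step(trunk, heads, batch)
changed = [b for b in range(len(heads)) if new_heads[b] != heads[b]]
before = forward(trunk, heads, 1.0)
after = forward(new_trunk, new_heads, 1.0)

print("heads whose weights changed:", changed)
print("trunk weight changed:", new_trunk != trunk)
print("head 0 output changed:", before[0] != after[0])
```

Two caveats apply when carrying this over to a real network. The weights of the unselected heads stay fixed, but their outputs can still shift because the shared layers are updated. And with a stateful optimizer such as Adam, moment estimates carried over from earlier batches could move a head's weights even on a batch where its gradient is zero, so the experiment is cleanest on the very first batch or with plain SGD.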

markus-hinsche avatar Jul 31 '19 12:07 markus-hinsche