lerobot

split dataset into train and val, and log val loss during training

Open tlpss opened this issue 1 year ago • 12 comments

What this does

Modifies the train script to split the dataset into train and validation subsets, and logs the validation loss during training.
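A minimal sketch of one way to make such a split, assuming it is done at the episode level (so near-duplicate consecutive frames from one demonstration never land in both subsets) and that each episode is described by `(start, end)` frame boundaries. `split_episodes`, `frame_indices` and `episode_bounds` are hypothetical names for illustration, not the PR's code or the lerobot API.

```python
import random


def split_episodes(num_episodes: int, val_fraction: float = 0.1, seed: int = 0):
    """Randomly assign whole episodes to train or val (hypothetical helper)."""
    if num_episodes < 2:
        raise ValueError("need at least two episodes to split")
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be in (0, 1)")
    episodes = list(range(num_episodes))
    random.Random(seed).shuffle(episodes)  # seeded shuffle, so the split is reproducible
    # Keep at least one episode in each subset.
    num_val = min(num_episodes - 1, max(1, round(num_episodes * val_fraction)))
    return sorted(episodes[num_val:]), sorted(episodes[:num_val])


def frame_indices(episode_ids, episode_bounds):
    """Expand episode ids into frame indices; bounds are (start, end) with end exclusive."""
    return [i for ep in episode_ids for i in range(*episode_bounds[ep])]


if __name__ == "__main__":
    # Toy dataset: 10 episodes of 50 frames each.
    bounds = [(50 * ep, 50 * (ep + 1)) for ep in range(10)]
    train_eps, val_eps = split_episodes(len(bounds), val_fraction=0.2)
    print(len(frame_indices(train_eps, bounds)), len(frame_indices(val_eps, bounds)))
```

The resulting frame indices could then, for example, be passed to `torch.utils.data.Subset(dataset, indices)` to build the two datasets.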

Based on #158

closes #250

How it was tested

Ran the train script with the default settings:

python scripts/train.py

How to checkout & try? (for the reviewer)


python scripts/train.py

@alexander-soare: feel free to provide feedback on this initial draft!

TODOs

  • [ ] modify all policy configurations to add the new parameters
  • [ ] document feature

tlpss, Jun 19 '24 12:06

@tlpss Thanks for your help on this!

Could you please plot the validation metrics against the success rate in a simulation environment, e.g. Aloha/ACT or PushT/Diffusion? I am very curious to see whether it helps to find the best checkpoint.

cc @marinabar who started experimenting on this, and @alexander-soare @thomwolf who might be interested as well

Cadene, Jun 19 '24 13:06

@Cadene, good suggestion!

I've started a training run for Diffusion Policy on the PushT task and will post the results here later today or tomorrow.

tlpss, Jun 19 '24 14:06

@Cadene @marinabar

I trained on the PushT env with diffusion policy using all default settings.

The (surprising) results are as follows: [image]

  • the validation loss increases during training instead of decreasing
  • the success rate seems to be positively correlated with this increasing validation loss, so the validation loss seems to be a lousy predictor of evaluation performance.

I would expect the validation loss to plateau due to the multimodality of the task (the agent can take a different solution than the one in the demonstrations), but not to increase this consistently, which makes me somewhat suspicious. I've started a second run on the Aloha sim dataset; according to the training tips here, the validation loss should indeed plateau there. If it increases again, there might be a bug in my code.

But so far, it seems that the validation loss is not very useful as a proxy for checkpoint performance.

tlpss, Jun 20 '24 09:06

@tlpss thanks for your awesome work so far! Could you please also plot training loss against that? What's most curious to me about the validation loss is that it doesn't start high then go down (even if it comes back up again).

alexander-soare, Jun 20 '24 10:06

@alexander-soare

Same plot + train loss (switched to log scale)

PushT + Diffusion: [image]

Aloha insertion + ACT: [image]

The Aloha + ACT val loss curve is closer to what I expected, probably also because the task has more 'critical' states (cf. this paper).

Nonetheless the validation loss still does not seem to correlate well with the evaluation success rate.

tlpss, Jun 20 '24 12:06

@tlpss insertion is probably not a great task as the success rate is low even in the paper. Owing to that, I think it's quite noisy and susceptible to uninteresting variations in your setup. Transfer is a good one.

I'd be really interested to see the validation loss early on for the first set of curves, to see whether it starts high. Since it's cheap to compute (relative to full evaluations), you could try running validation very frequently.

Also, if you validate on the training data, do you see what you expect (matches the training curve)?

Btw don't feel obliged to try all these things! Just spitballing here.
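Both suggestions above (frequent validation, and validating on training data as a sanity check) can share one code path. A minimal, framework-free sketch, assuming the loss is reported as a per-batch mean: the dataset-level loss should then weight each batch by its size, otherwise a smaller last batch skews the result. `mean_loss`, `batched`, `toy_mse` and `VAL_FREQ` are hypothetical names, and the toy loss stands in for a real policy loss.

```python
from typing import Callable, List

Batch = List[float]


def mean_loss(per_batch_loss: Callable[[Batch], float], batches: List[Batch]) -> float:
    """Dataset-level loss: per-batch mean losses weighted by batch size.

    In a PyTorch loop this would run with the policy in eval mode
    (policy.eval()) and inside torch.no_grad().
    """
    total, count = 0.0, 0
    for batch in batches:
        total += per_batch_loss(batch) * len(batch)  # undo the per-batch averaging
        count += len(batch)
    if count == 0:
        raise ValueError("no samples to evaluate")
    return total / count


def batched(items: Batch, batch_size: int) -> List[Batch]:
    """Split a list into consecutive batches; the last one may be smaller."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def toy_mse(batch: Batch) -> float:
    # Toy stand-in for a policy loss: each item is one sample's prediction error.
    return sum(e * e for e in batch) / len(batch)


if __name__ == "__main__":
    VAL_FREQ = 2  # hypothetical setting: validation is cheap, so run it often
    train_errors = [0.1, 0.2, 0.3, 0.4, 0.5]
    val_errors = [0.3, 0.6, 0.9]
    for step in range(1, 5):
        if step % VAL_FREQ == 0:
            val_loss = mean_loss(toy_mse, batched(val_errors, 2))
            # Sanity check: the same code path on training samples should track the train loss.
            train_check = mean_loss(toy_mse, batched(train_errors, 2))
            print(step, round(val_loss, 4), round(train_check, 4))
```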

alexander-soare, Jun 20 '24 12:06

@tlpss here's what I got for validation loss with ACT on Aloha Transfer Cube:

[image: Screenshot 2024-06-20 at 20 22 30]

So the validation loss does indeed start high, @alexander-soare. Very curious to see your results!

marinabar, Jun 20 '24 18:06

So I ran with the transfer cube task now and also made some better plots.

PushT + Diffusion: [image]

  • correlation between eval success and val loss: 0.73
  • correlation between eval success and step: 0.8

Aloha transfer cube + ACT: [image]

  • correlation between eval success and val loss: -0.57
  • correlation between eval success and step: 0.63

Comparison of validation loss and success rank:

step   validation/val_loss   eval/pc_success   success_rank
240    0.218958              70.0              5.0
320    0.219605              72.0              4.0
280    0.220379              66.0              7.5
260    0.220709              68.0              6.0
220    0.220934              86.0              1.0
300    0.220936              76.0              3.0
200    0.221082              56.0              12.0
160    0.221141              66.0              7.5
180    0.221228              64.0              9.0
140    0.221485              82.0              2.0
120    0.224542              52.0              14.0
100    0.228353              62.0              10.5
80     0.231218              48.0              15.0
60     0.232430              54.0              13.0
40     0.249941              62.0              10.5
20     0.255676              46.0              16.0
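A quick way to check these numbers, using only the 16 rows of this table: a plain Pearson correlation (Python 3.10+ also ships `statistics.correlation`) and a descending rank with averaged ties. This is an illustrative sketch, not the script used for the plots. With these rows it gives about -0.57 (val loss) and 0.63 (step), matching the transfer cube figures quoted above, and the rank function reproduces the `success_rank` column.

```python
from statistics import mean

# (step, validation/val_loss, eval/pc_success) copied from the table above.
ROWS = [
    (240, 0.218958, 70.0), (320, 0.219605, 72.0), (280, 0.220379, 66.0),
    (260, 0.220709, 68.0), (220, 0.220934, 86.0), (300, 0.220936, 76.0),
    (200, 0.221082, 56.0), (160, 0.221141, 66.0), (180, 0.221228, 64.0),
    (140, 0.221485, 82.0), (120, 0.224542, 52.0), (100, 0.228353, 62.0),
    (80, 0.231218, 48.0), (60, 0.232430, 54.0), (40, 0.249941, 62.0),
    (20, 0.255676, 46.0),
]


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sum((x - mx) ** 2 for x in xs) ** 0.5
    sy = sum((y - my) ** 2 for y in ys) ** 0.5
    return cov / (sx * sy)


def descending_ranks(values):
    """Rank 1 = largest value; tied values share the average of their ranks."""
    order = sorted(range(len(values)), key=lambda i: -values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1  # average of 1-based ranks i+1 .. j+1
        i = j + 1
    return ranks


steps, losses, success = zip(*ROWS)
print(round(pearson(losses, success), 2), round(pearson(steps, success), 2))
print(descending_ranks(success))
```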

It seems that:

  1. the validation loss behaves as expected for the transfer cube task (high initially, then plateauing);
  2. the validation loss is somewhat indicative of relative success rates, but 'time' seems to be a better predictor. That is, based on these two runs, testing the N latest checkpoints seems better than testing the N checkpoints with the lowest validation loss.
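The comparison in point 2 can be sketched directly on the 16 checkpoints of the table above, assuming each strategy picks N checkpoints and we look at the mean and the best eval success among the picked ones. `lowest_val_loss`, `latest` and `summarize` are hypothetical helpers. On a single run the answer can depend on N: on this table, N = 3 favors the latest checkpoints while N = 5 favors the lowest-val-loss set (it contains the 86% checkpoint at step 220), so repeating this over more runs seems worthwhile.

```python
from statistics import mean

# (step, validation/val_loss, eval/pc_success) copied from the table above.
ROWS = [
    (240, 0.218958, 70.0), (320, 0.219605, 72.0), (280, 0.220379, 66.0),
    (260, 0.220709, 68.0), (220, 0.220934, 86.0), (300, 0.220936, 76.0),
    (200, 0.221082, 56.0), (160, 0.221141, 66.0), (180, 0.221228, 64.0),
    (140, 0.221485, 82.0), (120, 0.224542, 52.0), (100, 0.228353, 62.0),
    (80, 0.231218, 48.0), (60, 0.232430, 54.0), (40, 0.249941, 62.0),
    (20, 0.255676, 46.0),
]


def lowest_val_loss(rows, n):
    """The n checkpoints with the lowest validation loss."""
    return sorted(rows, key=lambda r: r[1])[:n]


def latest(rows, n):
    """The n most recent checkpoints."""
    return sorted(rows, key=lambda r: r[0], reverse=True)[:n]


def summarize(picked):
    """(mean, best) eval success among the picked checkpoints."""
    successes = [r[2] for r in picked]
    return round(mean(successes), 2), max(successes)


for n in (3, 5):
    print(f"N={n}: lowest val loss {summarize(lowest_val_loss(ROWS, n))}, "
          f"latest {summarize(latest(ROWS, n))}")
```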

tlpss, Jun 21 '24 09:06

> @tlpss insertion is probably not a great task as the success rate is low even in the paper. Owing to that, I think it's quite noisy and susceptible to uninteresting variations in your setup. Transfer is a good one.

Thanks for the hint! I added a plot for the transfer cube task above.

> Also, if you validate on the training data, do you see what you expect (matches the training curve)?

That is a good suggestion; I should've thought of this sanity check.

For ACT, the val loss is similar to, but not equal to, the train loss. Could this be due to the KL loss and the difference between train and eval mode for the z vector? [image]

For Diffusion Policy the losses are similar, and I guess the differences are also due to eval vs. train mode?

(updated with a longer run) [image]

tlpss, Jun 21 '24 09:06

> @tlpss here's what I got for validation loss with ACT on Aloha Transfer Cube:
>
> [image: Screenshot 2024-06-20 at 20 22 30]
>
> So the validation loss does indeed start high @alexander-soare And very curious to see your results!

Hi @marinabar,

Thanks for the plot! My loss values seem to be quite a bit higher than yours. Maybe I've made a mistake somewhere; do you have a code snippet I can compare my code with?

tlpss, Jun 21 '24 09:06

@tlpss thanks, you are delivering massive value :D The results are interesting and IMO beg many more questions.

> For ACT, the val loss is similar but not equal to the train loss, I believe this is due to the KL loss and the difference between train and eval mode of the z vector?

I believe the KL-div loss should be calculated the same way regardless of eval vs. training mode, but it might need a closer look (if it's interesting to anyone).
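For reference, assuming ACT's latent is a diagonal Gaussian regularized toward a standard normal prior (the CVAE setup described in the ACT paper), the KL term has a closed form in the encoder outputs μ and log σ². The sketch below (an illustrative function, not lerobot's code) shows it depends only on those outputs, not on whether z is sampled or fixed, which is consistent with the point above. If train and val losses still differ, other train/eval-mode differences (e.g. dropout, if the model uses it) or how z feeds the reconstruction term would be natural places to look.

```python
import math
from typing import Sequence


def kl_to_standard_normal(mu: Sequence[float], log_var: Sequence[float]) -> float:
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dimensions.

    Closed form: 0.5 * sum(exp(log_var) + mu**2 - 1 - log_var).
    """
    if len(mu) != len(log_var):
        raise ValueError("mu and log_var must have the same length")
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))


print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0: posterior equals the prior
print(kl_to_standard_normal([1.0], [0.0]))            # 0.5: unit shift of the mean
```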

alexander-soare, Jun 21 '24 11:06

@alexander-soare

I think I'll close this one?

For now, my main question was whether the validation loss can be used as a predictor of evaluation performance.

My (premature) conclusion is that the validation loss is not very useful as a proxy for actual evaluation performance in behavior cloning, due to the multimodality of the action distributions (the action in the validation episode was not the only good choice). For tasks with limited multimodality in the action distribution (i.e. many critical states) it can serve as a predictor, but otherwise it is rather uncorrelated with success.

This leaves real-world policy evaluation as the only option unfortunately. Curious to hear if anyone has suggestions on how to limit the amount of real-world evaluation that is required to compare methods/checkpoints/sensor input combinations...
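A toy calculation of the multimodality argument above, under simplifying assumptions (1-D actions, and demonstrations that go left (-1) or right (+1) with equal probability from the same state). A policy that commits to one valid mode is penalized on every validation sample that took the other mode, while a policy that averages the modes outputs an action no demonstration took. Under an MSE loss the averaging policy scores better; under an L1 loss (as used for ACT's action reconstruction) the two tie. Either way the loss cannot single out the better-behaved policy. The data is hypothetical.

```python
from statistics import mean

# Hypothetical validation targets: from the same state, half the demos go left, half go right.
VAL_TARGETS = [-1.0, 1.0] * 50


def mse(pred, targets):
    return mean((pred - t) ** 2 for t in targets)


def l1(pred, targets):
    return mean(abs(pred - t) for t in targets)


committed = 1.0  # always picks a valid mode (right)
averaged = 0.0   # averages the two modes: an action no demonstration took
print("MSE:", mse(committed, VAL_TARGETS), mse(averaged, VAL_TARGETS))  # 2.0 vs 1.0
print("L1: ", l1(committed, VAL_TARGETS), l1(averaged, VAL_TARGETS))    # 1.0 vs 1.0
```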

tlpss, Jul 31 '24 14:07