KeyError: 'global_step' when loading the TAPIR weights

Open KiritoHarlod opened this issue 1 year ago • 5 comments

I downloaded tapir_checkpoint_panning.npy from the link provided, but it fails to load. I printed ckpt_state.keys() in the restore function in experiment_utils.py and found that ckpt_state contains only the keys 'params' and 'state'; there is no 'global_step'. Below is the command I used to run the evaluation:

python3 ./tapnet/experiment.py --config=./tapnet/configs/tapir_config.py --jaxline_mode=eval_davis_points --config.checkpoint_dir=./tapnet/checkpoint/ --config.experiment_kwargs.config.davis_points_path=/tapvid_davis/tapvid_davis.pkl

How do I get the code to run? Thanks!

KiritoHarlod, Apr 28 '24 13:04

I guess we didn't test this codepath; this is more intended for evaluating checkpoints you've trained yourself.

The logic for restoring the Experiment object from the checkpoint is here: https://github.com/google-deepmind/tapnet/blob/main/utils/experiment_utils.py#L160 -- if you change that line to just initialize the global_step to 0, it should probably get you past that error. I'm not sure how much more modification it will need to re-initialize everything correctly, but hopefully not too much.
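A minimal sketch of that fallback, assuming ckpt_state is the plain dict loaded from the .npy file. The helper name restore_global_step is hypothetical and not part of tapnet; it only illustrates defaulting the step to 0 when the key is absent.

```python
def restore_global_step(ckpt_state, default=0):
    """Return the saved step, or `default` if the checkpoint has none.

    `ckpt_state` is assumed to be a plain dict. This helper is illustrative
    and is not part of tapnet.
    """
    return int(ckpt_state.get("global_step", default))


# Released-style checkpoint: only 'params' and 'state' are present.
released = {"params": {}, "state": {}}
# Self-trained-style checkpoint: the step is stored alongside the weights.
trained = {"params": {}, "state": {}, "global_step": 1234}

print(restore_global_step(released))
print(restore_global_step(trained))
```

For evaluation the step value is not used to compute predictions, so defaulting to 0 should only affect what gets logged.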

cdoersch, Apr 28 '24 18:04

@cdoersch OK, I made a few simple changes that let tapir_checkpoint_panning.npy load successfully, but I don't know whether they affect TAPIR's results. I'd appreciate your advice.

# The released checkpoint has no 'global_step', so fall back to step 0.
experiment_state.global_step = 0 if 'global_step' not in ckpt_state.keys() else int(ckpt_state['global_step'])
exp_mod = experiment_state.experiment_module
for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
  if name == 'opt_state':
    # Fall back to the 'state' entry when the checkpoint has no 'opt_state'.
    name = 'state' if name not in ckpt_state.keys() else name
  setattr(exp_mod, attr, utils.bcast_local_devices(ckpt_state[name]))

KiritoHarlod, Apr 29 '24 04:04

Interesting -- I'm fairly sure that isn't a correct initialization for opt_state (it should be an Adam state), but for evaluation I guess it doesn't matter? IIRC Jaxline will call the full experiment init (including parameters and opt_state) and then overwrite with checkpoint values, so maybe the right thing to do is just to not set any attributes at all if name=='opt_state'?
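A rough sketch of that idea: overwrite only the attributes the checkpoint actually contains and leave everything else at its freshly initialized value. DummyExperiment, restore_attrs, and the attribute mapping below are hypothetical stand-ins rather than tapnet code; in the real restore function the transform would presumably be utils.bcast_local_devices.

```python
def restore_attrs(exp_mod, ckpt_state, checkpoint_attrs, transform=lambda x: x):
    """Overwrite only the attributes that the checkpoint contains.

    Returns the checkpoint names that were skipped, so the caller can log them.
    """
    skipped = []
    for attr, name in checkpoint_attrs.items():
        if name not in ckpt_state:
            # Keep the value produced by the experiment's own init.
            skipped.append(name)
            continue
        setattr(exp_mod, attr, transform(ckpt_state[name]))
    return skipped


class DummyExperiment:
    """Hypothetical stand-in for an already-initialized experiment module."""

    def __init__(self):
        self._params = "init_params"
        self._state = "init_state"
        self._opt_state = "init_adam_state"


# Assumed attribute-to-checkpoint-key mapping, for illustration only.
attrs = {"_params": "params", "_state": "state", "_opt_state": "opt_state"}
exp = DummyExperiment()
skipped = restore_attrs(exp, {"params": "ckpt_params", "state": "ckpt_state"}, attrs)
print(skipped, exp._opt_state)
```

Under this approach opt_state keeps whatever the init produced, which avoids assigning the model 'state' to it.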

cdoersch, Apr 29 '24 09:04

What would be the correct way to run the pretrained TAPIR model on some videos locally (on a GPU)?

jaoguerreiro, Apr 29 '24 09:04

@cdoersch Thank you for the suggestion, I'll give it a try.

KiritoHarlod, Apr 30 '24 03:04