KeyError: 'global_step' when loading the TAPIR weights
I downloaded tapir_checkpoint_panning.npy from the link provided, but it fails to load. I printed ckpt_state.keys() in the restore function in experiment_utils.py and found that ckpt_state contains only the keys 'params' and 'state'; there is no 'global_step' key. Below is the command I used to run the evaluation:
python3 ./tapnet/experiment.py --config=./tapnet/configs/tapir_config.py --jaxline_mode=eval_davis_points --config.checkpoint_dir=./tapnet/checkpoint/ --config.experiment_kwargs.config.davis_points_path=/tapvid_davis/tapvid_davis.pkl
How do I get the code to run? Thanks!
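To check which keys a checkpoint file contains without going through experiment_utils.py, you can load the .npy file directly. This is a minimal sketch that assumes the file stores a pickled Python dict, which matches the 'params' and 'state' keys printed above. `checkpoint_keys` is a hypothetical helper, and the demo writes a stand-in file instead of using the real checkpoint.

```python
import os
import tempfile

import numpy as np


def checkpoint_keys(path):
  """Returns the sorted top-level keys of a .npy file holding a pickled dict."""
  # The dict is stored as a 0-d object array: pickle loading must be enabled,
  # and .item() unwraps the array back into the original dict.
  ckpt_state = np.load(path, allow_pickle=True).item()
  return sorted(ckpt_state.keys())


# Hypothetical stand-in for tapir_checkpoint_panning.npy with the same two keys.
fake_ckpt = {'params': {'w': np.zeros(2)}, 'state': {}}
path = os.path.join(tempfile.mkdtemp(), 'fake_ckpt.npy')
np.save(path, fake_ckpt)
print(checkpoint_keys(path))  # ['params', 'state']
```

Any code that indexes `ckpt_state['global_step']` unconditionally will raise the KeyError above for such a file.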
I guess we didn't test this codepath; this is more intended for evaluating checkpoints you've trained yourself.
The logic for restoring the Experiment object from the checkpoint is here: https://github.com/google-deepmind/tapnet/blob/main/utils/experiment_utils.py#L160. If you change that line to simply initialize global_step to 0, it should get you past that error. I'm not sure how much more modification it will need to re-initialize everything correctly, but hopefully not too much.
@cdoersch OK, I made a few simple changes that let tapir_checkpoint_panning.npy load successfully, but I don't know whether they affect TAPIR's results. I'd appreciate your advice.
```python
# Released checkpoints have no 'global_step', so default it to 0.
experiment_state.global_step = 0 if 'global_step' not in ckpt_state.keys() else int(ckpt_state['global_step'])
exp_mod = experiment_state.experiment_module
for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
  if name == 'opt_state':
    # No 'opt_state' in the checkpoint: fall back to its 'state' entry.
    name = 'state' if name not in ckpt_state.keys() else name
  setattr(exp_mod, attr, utils.bcast_local_devices(ckpt_state[name]))
```
Interesting. I'm fairly sure that isn't a correct initialization for opt_state (it should be an Adam state), but for evaluation I guess it doesn't matter? IIRC, Jaxline calls the full experiment init (including parameters and opt_state) and then overwrites them with the checkpoint values, so maybe the right thing to do is simply not to set any attribute when name == 'opt_state'?
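The suggestion above can be sketched as a restore loop that defaults global_step to 0 and leaves the optimizer state that Jaxline built during init untouched. This is a slightly more conservative variant: it skips 'opt_state' only when the checkpoint lacks it. The sketch assumes ckpt_state is a plain dict. `restore_for_eval`, the identity stand-in for `utils.bcast_local_devices`, and the `CHECKPOINT_ATTRS` mapping in the demo are all hypothetical, chosen so the sketch runs without JAX or tapnet.

```python
from types import SimpleNamespace


def bcast_local_devices(x):
  # Hypothetical stand-in for utils.bcast_local_devices, which replicates a
  # value across local devices; the identity here keeps the sketch runnable.
  return x


def restore_for_eval(experiment_state, ckpt_state):
  """Restores a checkpoint that may lack 'global_step' and 'opt_state'."""
  # Released TAPIR checkpoints have no 'global_step'; start counting from 0.
  experiment_state.global_step = int(ckpt_state.get('global_step', 0))
  exp_mod = experiment_state.experiment_module
  for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
    if name == 'opt_state' and name not in ckpt_state:
      # Keep the optimizer state created during init rather than filling it
      # with an unrelated entry such as the model 'state'.
      continue
    setattr(exp_mod, attr, bcast_local_devices(ckpt_state[name]))


# Hypothetical experiment module whose opt_state was already set by init.
exp_mod = SimpleNamespace(
    CHECKPOINT_ATTRS={'_params': 'params', '_state': 'state',
                      '_opt_state': 'opt_state'},
    _params=None, _state=None, _opt_state='adam-state-from-init')
experiment_state = SimpleNamespace(global_step=None, experiment_module=exp_mod)
restore_for_eval(experiment_state, {'params': {'w': 1}, 'state': {'bn': 2}})
print(experiment_state.global_step, exp_mod._opt_state)  # 0 adam-state-from-init
```

Because evaluation never applies an optimizer update, leaving opt_state at its initialized value should not change the evaluation metrics.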
What would be the correct way to run the pretrained TAPIR model on my own videos locally (on a GPU)?
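As a starting point, here is a sketch of the input-preparation step only. It assumes the conventions of the public TAPIR demo notebook: uint8 RGB frames (already resized to the resolution the model expects) are rescaled to float32 in [-1, 1], and query points are given as (t, y, x) in pixel coordinates. The helper names are hypothetical. The model call itself is deliberately omitted because its signature has changed between tapnet versions, so check the demo for the version you install.

```python
import numpy as np


def preprocess_frames(frames):
  """Maps uint8 RGB frames of shape [T, H, W, 3] to float32 values in [-1, 1]."""
  frames = frames.astype(np.float32)
  return frames / 255.0 * 2.0 - 1.0


def make_query_points(frame_idx, yx_points):
  """Builds an [N, 3] array of (t, y, x) query points, all on one frame."""
  yx = np.asarray(yx_points, dtype=np.float32)
  t = np.full((yx.shape[0], 1), frame_idx, dtype=np.float32)
  return np.concatenate([t, yx], axis=1)


# Example: an 8-frame black video and two points to track from frame 0.
video = np.zeros((8, 64, 64, 3), dtype=np.uint8)
frames = preprocess_frames(video)               # every value is -1.0
queries = make_query_points(0, [[10, 20], [30, 40]])
print(frames.shape, queries.shape)              # (8, 64, 64, 3) (2, 3)
```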
@cdoersch Thank you for the suggestion, I'll give it a try.