Grounded-Image-Captioning
Key Error during training
Hi,
I am trying to use your codebase for some experiments and during training I get the following error:
```
$ python train.py --id CE-scan-sup-0.1kl --caption_model topdown --input_json data/flickrtalk.json \
    --input_fc_dir data/flickrbu/flickrbu_fc --input_att_dir data/flickrbu/flickrbu_att \
    --input_box_dir data/flickrbu/flickrbu_box --input_label_h5 data/flickrtalk_label.h5 \
    --batch_size 29 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 \
    --checkpoint_path log/CE-scan-sup-0.1kl --save_checkpoint_every 1000 --val_images_use -1 \
    --max_epochs 30 --att_supervise True --att_supervise_weight 0.1
Constructing SCAN model...
scan_model_path:misc/SCAN/runs/f30k_SCAN_POS1/checkpoint/model_best.pth.tar
Done
tensorboardX is not installed
DataLoader loading json file: data/flickrtalk.json
vocab size is 7000
DataLoader loading h5 file: data/flickrbu/flickrbu_fc data/flickrbu/flickrbu_att data/flickrbu/flickrbu_box data/flickrtalk_label.h5
max sequence length in data is 16
read 31014 image features
assigned 29000 images to split train
assigned 1014 images to split val
assigned 1000 images to split test
Read data: 0.7768440246582031
Traceback (most recent call last):
  File "train.py", line 291, in <module>
    train(opt)
  File "train.py", line 180, in train
    model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, box_inds)
  File "/users/vlad/anaconda3/envs/gvd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/users/vlad/anaconda3/envs/gvd/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/users/vlad/anaconda3/envs/gvd/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/users/vlad/anaconda3/envs/gvd/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/users/vlad/anaconda3/envs/gvd/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/users/vlad/anaconda3/envs/gvd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/users/vlad/image_grounding/Grounded-Image-Captioning/misc/loss_wrapper.py", line 45, in forward
    _, grd_weights,noun_mask= get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, labels[:,1:].detach(), vars(self.opt))
  File "/users/vlad/image_grounding/Grounded-Image-Captioning/misc/rewards.py", line 68, in get_self_critical_reward
    gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)}
  File "/users/vlad/image_grounding/Grounded-Image-Captioning/misc/rewards.py", line 68, in <dictcomp>
    gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)}
KeyError: 29
Terminating BlobFetcher
```
I followed all the steps described, and I am able to run the evaluation, but training fails with the error above. After some investigation I think a reshape is missing somewhere: the dataloader loads multiple captions per image, but this doesn't seem to be reflected here: https://github.com/YuanEZhou/Grounded-Image-Captioning/blob/77295a6e36de817f173435e809effc3396469ee3/misc/rewards.py#L70 Any suggestions on how to overcome this?
Thanks!
I have narrowed it down: the problem seems to be https://github.com/YuanEZhou/Grounded-Image-Captioning/blob/77295a6e36de817f173435e809effc3396469ee3/train.py#L97, which implies that https://github.com/YuanEZhou/Grounded-Image-Captioning/blob/77295a6e36de817f173435e809effc3396469ee3/misc/rewards.py#L44 is wrongly computed.
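To illustrate the suspicion, here is a standalone toy reproduction of the dict comprehension from `misc/rewards.py` line 68 (the data and the helper name `expand_gts` are made up for illustration; the real `gts` holds reference-caption arrays per image). Under the assumption that `seq_per_img` is wrongly computed as 1, it produces exactly `KeyError: 29`:

```python
# Toy reproduction of the indexing in misc/rewards.py (line 68).
# `expand_gts` and the caption strings are hypothetical stand-ins.
num_images = 29      # images per batch (matches --batch_size 29)
seq_per_img = 5      # captions sampled per image

# gts maps image index -> list of reference captions, one entry per image.
gts = {i: ["caption %d" % i] for i in range(num_images)}

def expand_gts(gts, batch_size, seq_per_img):
    # Maps each of the 2*batch_size generated sequences (sampled + greedy
    # baseline) back to the ground truth of its source image.
    return {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)}

# Expected case: batch_size already includes the seq_per_img factor,
# so every index stays within 0..28 and no error occurs.
expanded = expand_gts(gts, num_images * seq_per_img, seq_per_img)
assert len(expanded) == 2 * num_images * seq_per_img

# Failure case: if seq_per_img is wrongly taken as 1, the comprehension
# asks for image 29, which does not exist -- the reported KeyError: 29.
try:
    expand_gts(gts, num_images * seq_per_img, 1)
except KeyError as e:
    print("KeyError:", e)
```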
I git cloned the code and it runs correctly for me.
You can set a breakpoint in /users/vlad/image_grounding/Grounded-Image-Captioning/misc/rewards.py, line 68, and check why you get key 29 (instead of 28), which causes the KeyError.
`(i % batch_size // seq_per_img) == 28` when `i == 289`, `batch_size == 29*5`, `seq_per_img == 5`.
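For reference, that arithmetic can be checked directly (values taken from this thread; note that `%` and `//` have equal precedence and associate left-to-right):

```python
# Check of the expected index arithmetic with the values quoted above.
batch_size = 29 * 5   # 29 images, each repeated seq_per_img times
seq_per_img = 5
i = 289               # last index produced by range(2 * batch_size)

# Parses as (i % batch_size) // seq_per_img.
assert i % batch_size == 144
assert 144 // seq_per_img == 28
assert i % batch_size // seq_per_img == 28   # largest key ever requested
```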