meshed-memory-transformer

RuntimeError: gather(): Expected dtype int64 for index, in beam_search/beam_search.py, line 26, in fn

Open • linhuixiao opened this issue 1 year ago • 1 comment

    Meshed-Memory Transformer Evaluation
    Evaluation: 0%| | 0/500 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "test.py", line 78, in
        scores = predict_captions(model, dict_dataloader_test, text_field)
      File "test.py", line 26, in predict_captions
        out, _ = model.beam_search(images, 20, text_field.vocab.stoi[''], 5, out_size=1)
      File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/captioning_model.py", line 70, in beam_search
        return bs.apply(visual, out_size, return_probs, **kwargs)
      File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/beam_search/beam_search.py", line 71, in apply
        visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs)
      File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/beam_search/beam_search.py", line 121, in iter
        self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
      File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/containers.py", line 30, in apply_to_states
        self._buffers[name] = fn(self._buffers[name])
      File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/beam_search/beam_search.py", line 26, in fn
        s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1,
    RuntimeError: gather(): Expected dtype int64 for index

linhuixiao, Jul 22 '22 08:07

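I have solved this bug.

The bug is in models/beam_search/beam_search.py, line 118. The flat candidate index selected_idx is an integer tensor, and the code divides it by candidate_logprob.shape[-1] with /. In newer PyTorch releases, / on integer tensors performs true division and returns a floating-point tensor, so selected_beam is no longer int64 and the torch.gather call reached through _expand_state rejects it as an index. Replace the line with an explicit floor division: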

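        # Old line: true division turns the integer index into a float tensor.
        # selected_beam = selected_idx / candidate_logprob.shape[-1]
        # New line: floor division keeps the int64 dtype that torch.gather needs.
        selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode="floor")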

linhuixiao, Jul 22 '22 08:07
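The following is a minimal, self-contained sketch (not the repository's code) that reproduces the dtype problem and the fix. The tensor sizes, the random scores, the topk-based candidate selection and the helper split_flat_indices are illustrative assumptions. It only mirrors the pattern visible in the issue: scores over (beam, word) pairs are flattened, the beam index is recovered by dividing the flat index by the last dimension, and that index is then used in torch.gather. It assumes a PyTorch version whose torch.div accepts rounding_mode.

```python
import torch


def split_flat_indices(selected_idx: torch.Tensor, vocab_size: int):
    """Hypothetical helper: split flat indices (beam * vocab_size + word)
    into beam and word indices.

    torch.div(..., rounding_mode="floor") keeps integer inputs as int64,
    which is the index dtype torch.gather expects.
    """
    beam = torch.div(selected_idx, vocab_size, rounding_mode="floor")
    word = selected_idx - beam * vocab_size
    return beam, word


# Hypothetical sizes, for illustration only.
batch_size, beam_size, vocab_size, hidden = 2, 3, 10, 4

# Score every (beam, word) pair, flatten per example and keep the top
# beam_size candidates; topk returns int64 flat indices.
candidate_logprob = torch.randn(batch_size, beam_size, vocab_size)
_, selected_idx = candidate_logprob.view(batch_size, -1).topk(beam_size, dim=-1)

# Pattern of the original line: `/` is true division, so the "beam index" is float.
float_beam = selected_idx / vocab_size
print(float_beam.dtype)  # torch.float32

# Pattern of the fixed line: floor division keeps int64.
selected_beam, selected_word = split_flat_indices(selected_idx, vocab_size)
print(selected_beam.dtype)  # torch.int64

# Reorder a dummy per-beam state of shape (batch, beam, hidden) by the chosen
# beams with torch.gather, as the traceback shows _expand_state doing.
state = torch.randn(batch_size, beam_size, hidden)
index = selected_beam.unsqueeze(-1).expand(batch_size, beam_size, hidden)
reordered = torch.gather(state, 1, index)

# Passing the float index instead fails with the error reported in the issue.
try:
    torch.gather(state, 1, float_beam.unsqueeze(-1).expand(batch_size, beam_size, hidden))
except RuntimeError as err:
    print(err)
```

Because the flat indices are never negative, rounding_mode="floor" gives the same result as the integer (truncating) division the original line relied on.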