meshed-memory-transformer
RuntimeError: gather(): Expected dtype int64 for index, in beam_search/beam_search.py, line 26, in fn
Meshed-Memory Transformer Evaluation
Evaluation: 0%|
Evaluation: 0%| | 0/500 [00:00<?, ?it/s]
Traceback (most recent call last):
File "test.py", line 78, in
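I have solved this bug. In recent PyTorch versions, the / operator on integer tensors performs true division and returns a float tensor, which gather() rejects as an index. Use floor division instead, in models/beam_search.py, line 118: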
# selected_beam = selected_idx / candidate_logprob.shape[-1]
selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode="floor")
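To see why the one-line change works, here is a minimal sketch that reproduces the dtype problem outside the repository. Only the names selected_idx, candidate_logprob and selected_beam come from the fix above; the tensor shapes, beam_scores and selected_words are assumptions made up for illustration and may not match the actual beam search code.

```python
import torch

# Hypothetical sizes, chosen only for this illustration.
batch_size, beam_size, vocab_size = 2, 3, 5
candidate_logprob = torch.randn(batch_size, beam_size, vocab_size)

# Top-k over the flattened (beam * vocab) axis returns flat int64 indices.
_, selected_idx = candidate_logprob.view(batch_size, -1).topk(beam_size, dim=-1)

# True division turns the int64 indices into a float tensor,
# which gather() does not accept as an index.
as_float = selected_idx / candidate_logprob.shape[-1]

# Floor division keeps the int64 dtype, so the result is a valid index.
selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode="floor")

# The position inside the vocabulary is the remainder of the same division.
selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]

# Hypothetical per-beam tensor, gathered with the recovered beam indices.
beam_scores = torch.zeros(batch_size, beam_size)
picked = torch.gather(beam_scores, 1, selected_beam)

print(as_float.dtype, selected_beam.dtype)
```

The rounding_mode argument requires PyTorch 1.8 or later. For non-negative indices such as these, selected_idx // candidate_logprob.shape[-1] should give the same values, though some PyTorch versions emit a deprecation warning for // on tensors.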