Issues using BART model for inference
I am trying to use scripts/prep.sh and scripts/inference.py to load the /reddit_vanilla_actual/checkpoint_best.pt BART checkpoint for inference. I have been having many issues, mostly related to package versions and the extended 2048 source positions.
Environment:
pytorch 1.7.1 py3.8_cuda10.2.89_cudnn7.6.5_0 pytorch
I first tried installing fairseq from source to access the examples module, but then I saw you have your own copy of fairseq in this repo, so I installed your version according to the instructions here:

cd fairseq
pip install --editable ./
python setup.py build develop
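
One sanity check that helped me with the version confusion (my own suggestion, not from the repo instructions) was confirming which fairseq copy Python actually imports:

python -c "import fairseq; print(fairseq.__file__)"

With the editable install above, this should print a path under ConvoSumm/code/fairseq (the same path the traceback below shows) rather than a separate site-packages copy.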
I binarized val.source and val.target and am running inference as such:
python scripts/inference.py /home/aadelucia/ConvoSumm/checkpoints/reddit_vanilla_actual checkpoint_best.pt /home/aadelucia/ConvoSumm/alexandra_test/data_processed /home/aadelucia/ConvoSumm/alexandra_test/data/val.source /home/aadelucia/ConvoSumm/alexandra_test/inference_output.txt 4 1 80 120 1 2048 ./misc/encoder.json ./misc/vocab.bpe
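
For context, the binarization that produced data_processed followed what I understand to be the standard fairseq BART recipe, applied to val.source/val.target after BPE-encoding them with the GPT-2 encoder. The flags below are my reconstruction, with the dict.txt paths and worker count as placeholders, not necessarily what scripts/prep.sh actually runs:

fairseq-preprocess \
    --source-lang source --target-lang target \
    --validpref /home/aadelucia/ConvoSumm/alexandra_test/data/val \
    --destdir /home/aadelucia/ConvoSumm/alexandra_test/data_processed \
    --srcdict dict.txt --tgtdict dict.txt \
    --workers 8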
And I get the following error:
Traceback (most recent call last):
  File "scripts/inference.py", line 42, in <module>
    hypotheses_batch = bart.sample(slines, beam=beam, lenpen=lenpen, min_len=min_len, no_repeat_ngram_size=3)
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/hub_utils.py", line 132, in sample
    batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/models/bart/hub_interface.py", line 108, in generate
    return super().generate(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/hub_utils.py", line 171, in generate
    for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/hub_utils.py", line 258, in _build_batches
    batch_iterator = self.task.get_batch_iterator(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/tasks/fairseq_task.py", line 244, in get_batch_iterator
    batch_sampler = dataset.batch_by_size(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/data/fairseq_dataset.py", line 145, in batch_by_size
    return data_utils.batch_by_size(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/data/data_utils.py", line 337, in batch_by_size
    return batch_by_size_vec(
  File "fairseq/data/data_utils_fast.pyx", line 20, in fairseq.data.data_utils_fast.batch_by_size_vec
  File "fairseq/data/data_utils_fast.pyx", line 27, in fairseq.data.data_utils_fast.batch_by_size_vec
AssertionError: Sentences lengths should not exceed max_tokens=1024
Am I using the wrong version of a package? Is there something extra needed for this to work?
Never mind, it seems to work when I pass in max_tokens=max_source_positions in scripts/inference.py. As far as I can tell, fairseq's batcher caps each input's length at max_tokens (1024 in this setup), independently of max_source_positions:
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    model_dir,
    checkpoint_file=model_file,
    data_name_or_path=bin_folder,
    gpt2_encoder_json=encoder_file,
    gpt2_vocab_bpe=vocab_file,
    max_source_positions=max_source_positions,
    max_tokens=max_source_positions,  # lift the 1024-token batching cap to match the 2048 source positions
)
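
For anyone else hitting this, here is a minimal sketch of how the fixed call slots into a generation loop. The bart.sample call (with no_repeat_ngram_size=3) is the one visible in the traceback above; everything else, including the flush helper, the batching, and the mapping of the numeric command-line arguments to batch_size/beam/lenpen/min_len, is my own simplified reconstruction, not the actual scripts/inference.py code:

import torch
from fairseq.models.bart import BARTModel

# Paths come from my command line above; the numeric values are my guess
# at how the CLI arguments map, so treat them as placeholders.
model_dir = "/home/aadelucia/ConvoSumm/checkpoints/reddit_vanilla_actual"
model_file = "checkpoint_best.pt"
bin_folder = "/home/aadelucia/ConvoSumm/alexandra_test/data_processed"
source_file = "/home/aadelucia/ConvoSumm/alexandra_test/data/val.source"
output_file = "/home/aadelucia/ConvoSumm/alexandra_test/inference_output.txt"
encoder_file = "./misc/encoder.json"
vocab_file = "./misc/vocab.bpe"
max_source_positions = 2048
batch_size, beam, lenpen, min_len = 4, 4, 1.0, 80

bart = BARTModel.from_pretrained(
    model_dir,
    checkpoint_file=model_file,
    data_name_or_path=bin_folder,
    gpt2_encoder_json=encoder_file,
    gpt2_vocab_bpe=vocab_file,
    max_source_positions=max_source_positions,
    max_tokens=max_source_positions,  # the actual fix
)
bart.cuda()
bart.eval()

def flush(slines, fout):
    # bart.sample is the call from the traceback; one summary per input line.
    with torch.no_grad():
        for hypo in bart.sample(slines, beam=beam, lenpen=lenpen,
                                min_len=min_len, no_repeat_ngram_size=3):
            fout.write(hypo + "\n")

with open(source_file) as fin, open(output_file, "w") as fout:
    slines = []
    for line in fin:
        slines.append(line.strip())
        if len(slines) == batch_size:
            flush(slines, fout)
            slines = []
    if slines:  # flush the last partial batch
        flush(slines, fout)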