Medical-Dialogue-System


Open • kaishxu opened this issue 3 years ago • 0 comments

Hello, I think there may be a couple of problems. In bertGPT/generate.py, past_length should start from 0 to match the training setting:

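length = 1  # problem: past_length should start from 0, as in training
# decoding loop
for i in range(100):
    # extend the attention mask by one position for the new token
    mask = F.pad(mask, (0, 1), "constant", 1.0)
    logits, past = decoder(prev_pred, mask, past=past, past_length=length)
    logits = logits.squeeze(1)
    # top-k sampling of the next token
    logits = top_k_logits(logits, k=top_k)
    probs = F.softmax(logits, dim=-1)
    prev_pred = torch.multinomial(probs, num_samples=1)
    sentence.append(prev_pred)
    # 102 is the [SEP] id in the standard BERT vocab: stop at end of sentence
    if prev_pred[0][0] == 102:
        break
    length += 1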
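A minimal sketch of why the starting value matters. It assumes past_length works like the cache offset in GPT-2-style decoders (the position ids of newly fed tokens start at past_length) and that training feeds the target sequence from position 0 with no cache. The helpers position_ids and decode_positions are hypothetical, not part of the repository, and plain Python lists stand in for tensors.

```python
def position_ids(past_length, seq_len):
    """Position ids for seq_len new tokens appended after past_length cached
    tokens (GPT-2-style offset; assumed to match how the decoder uses past_length)."""
    return list(range(past_length, past_length + seq_len))


def decode_positions(start_length, num_steps):
    """Collect the position id given to each token fed one at a time in a loop
    that starts with length = start_length and increments it after every step,
    mirroring the loop structure in bertGPT/generate.py."""
    length = start_length
    positions = []
    for _ in range(num_steps):
        # one token (prev_pred) is fed per step
        positions.extend(position_ids(length, 1))
        length += 1
    return positions


# Assumed training setting: the whole target sequence is fed at once, no cache.
training_positions = position_ids(0, 5)

print(training_positions)        # [0, 1, 2, 3, 4]
print(decode_positions(0, 5))    # [0, 1, 2, 3, 4]  matches training
print(decode_positions(1, 5))    # [1, 2, 3, 4, 5]  shifted by one
```

Under these assumptions, starting with length = 1 gives every generated token a position one higher than the position it had during training.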

In gpt2/gpt2_test_mmi.py, curr_input_tensors should be truncated from the front, i.e. keep the last 300 tokens rather than the first 300:

curr_input_tensors = curr_input_tensors[:, 0:300]   # × keeps the oldest 300 tokens
curr_input_tensors = curr_input_tensors[:, -300:]   # √ keeps the most recent 300 tokens
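A toy illustration of the two slicing variants. It assumes the model generates its reply right after the last token of curr_input_tensors, so the most recent context sits at the end of the sequence. MAX_LEN, truncate_back and truncate_front are hypothetical stand-ins (MAX_LEN plays the role of 300), and nested lists mimic the [batch, seq_len] tensor; slicing along dimension 1 behaves the same way.

```python
MAX_LEN = 4  # hypothetical stand-in for 300 in gpt2_test_mmi.py


def truncate_back(batch, max_len):
    """Keep the first max_len tokens of each row (the [:, 0:300] variant)."""
    return [row[:max_len] for row in batch]


def truncate_front(batch, max_len):
    """Keep the last max_len tokens of each row (the [:, -300:] variant).
    Assumes max_len > 0: row[-0:] would return the whole row."""
    return [row[-max_len:] for row in batch]


# One dialogue history, oldest turn first; each string is a toy token.
history = [["q1", "a1", "q2", "a2", "q3"]]

print(truncate_back(history, MAX_LEN))   # [['q1', 'a1', 'q2', 'a2']]  drops the latest question
print(truncate_front(history, MAX_LEN))  # [['a1', 'q2', 'a2', 'q3']]  keeps the latest question
```

Under that assumption, cutting the tail removes exactly the context the reply should respond to, while cutting the head only drops the oldest turns.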

kaishxu, Jun 22 '21 13:06