Medical-Dialogue-System
some details
Hello, it seems there are some problems.

In `bertGPT/generate.py`, `past_length` should start from 0, following the training settings, but the decoding loop initializes it to 1:
```python
length = 1
# decoding loop
for i in range(100):
    mask = F.pad(mask, (0, 1), "constant", 1.0)
    logits, past = decoder(prev_pred, mask, past=past, past_length=length)
    logits = logits.squeeze(1)
    logits = top_k_logits(logits, k=top_k)
    probs = F.softmax(logits, dim=-1)
    prev_pred = torch.multinomial(probs, num_samples=1)
    sentence.append(prev_pred)
    if prev_pred[0][0] == 102:
        break
    length += 1
```
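To illustrate why the initial value matters: if `past_length` is used to offset the position ids of the token fed in at each step (a common pattern, and only an assumption here), then starting the counter at 1 shifts every generated token's position embedding by one relative to training, where the first decoded token sits at position 0. A minimal standalone sketch of that mismatch (the embedding table and helper are hypothetical, not the repo's code):

```python
import torch

# Stand-in position-embedding table, not the model's own weights.
pos_emb = torch.nn.Embedding(512, 8)

def step_position_embedding(past_length: int) -> torch.Tensor:
    # Position assigned to the single new token at this decoding step,
    # assuming the decoder indexes positions starting at past_length.
    return pos_emb(torch.tensor([past_length]))

correct = step_position_embedding(0)   # matches training (first token at position 0)
shifted = step_position_embedding(1)   # off by one: position 1 is looked up instead
print(torch.allclose(correct, shifted))  # False -> a different embedding is used
```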
In `gpt2/gpt2_test_mmi.py`, the sequence `curr_input_tensors` should be truncated from the front, so that the most recent tokens are kept:
```python
curr_input_tensors = curr_input_tensors[:, 0:300]   # ✗ keeps the oldest 300 tokens, drops recent context
curr_input_tensors = curr_input_tensors[:, -300:]   # ✓ drops tokens from the front, keeps the most recent 300
```
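A tiny standalone demonstration of the difference (the 300-token limit comes from the snippet above; the tensor values are made up for illustration):

```python
import torch

max_len = 300
# Pretend a dialogue history that has grown past the limit (values are arbitrary).
curr_input_tensors = torch.arange(1, 351).unsqueeze(0)   # shape (1, 350)

head = curr_input_tensors[:, 0:max_len]   # keeps tokens 1..300, drops the newest 50
tail = curr_input_tensors[:, -max_len:]   # keeps tokens 51..350, the most recent 300

print(head[0, -1].item())  # 300 -> the latest turns of the dialogue are lost
print(tail[0, -1].item())  # 350 -> the most recent context is preserved
```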