Imagination-Augmented-Agents
Mistake in the distil loss?
Hello,
In your 4.imagination-augmented agent.ipynb, you specify the distil loss function as follows:
distil_loss = 0.01 * (F.softmax(logit).detach() * F.log_softmax(distil_logit)).sum(1).mean()
Didn't you forget the minus sign in front of 0.01?
According to Wikipedia, the cross entropy between two distributions p and q is:
(-1) * sum {p(x) * log(q(x))}
link for cross entropy definition
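
To make the sign question concrete, here is a minimal plain-Python sketch (standard library only, not the notebook's PyTorch code). The logit values and the helper names `softmax`, `log_softmax`, and `cross_entropy` are hypothetical and chosen only for illustration. It compares the summed term as written in the loss above with the cross entropy from the definition.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits):
    # log(softmax(x)) computed via the log-sum-exp trick.
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]

def cross_entropy(p_logits, q_logits):
    # H(p, q) = -sum_x p(x) * log(q(x)), with p and q given as logits.
    p = softmax(p_logits)
    log_q = log_softmax(q_logits)
    return -sum(pi * lqi for pi, lqi in zip(p, log_q))

# Hypothetical logits for a single sample (assumed values).
logit = [2.0, 0.5, -1.0]         # plays the role of the target policy
distil_logit = [0.1, 0.2, 0.3]   # plays the role of the distilled policy

# The term as written in the issue, without the leading minus sign.
unsigned_term = sum(
    pi * lqi for pi, lqi in zip(softmax(logit), log_softmax(distil_logit))
)

ce = cross_entropy(logit, distil_logit)
ce_self = cross_entropy(logit, logit)

print("sum p * log q  :", unsigned_term)  # negative, since log q <= 0
print("cross entropy  :", ce)             # positive
print("H(p, p)        :", ce_self)        # smallest value, reached when q = p
```

Since every log-probability is at most zero, the term without the minus sign is never positive, and it equals the negated cross entropy. If that term is added to a loss that is minimized, the optimizer would therefore be increasing the cross entropy between the two policies rather than decreasing it.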