BranchingDQN

The loss function calculation is inconsistent with the paper

Open • Yaxin1996Z opened this issue 4 years ago • 1 comment

Consider this line of code:

`loss = F.mse_loss(expected_q_vals, current_q_values)`

Here `expected_q_vals` is the target averaged over all branches, the same as y in equation 6 of the paper. However, the loss function in equation 7 uses y_d, which means the target of each individual branch.
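For reference, the two equations being compared can be written out as follows. This is a paraphrase in the paper's notation as I understand it (N action branches, Q_d^- the d-th branch of the target network, D the replay buffer), so it is worth checking against the paper itself:

```latex
% Equation 6: one shared TD target, the mean over branches
y = r + \gamma \frac{1}{N} \sum_{d=1}^{N}
    Q_d^{-}\left(s', \operatorname*{arg\,max}_{a'_d \in \mathcal{A}_d} Q_d(s', a'_d)\right)

% Equation 7: loss averaged over branches, written with a per-branch target y_d
L = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}}
    \left[ \frac{1}{N} \sum_{d=1}^{N} \left( y_d - Q_d(s, a_d) \right)^2 \right]
```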

Moreover, `expected_q_vals` has shape [128, 1], but `current_q_values` has shape [128, 4]. This causes the following warning:

UserWarning: Using a target size (torch.Size([128, 4])) that is different to the input size (torch.Size([128, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
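A minimal sketch that reproduces the warning and shows one way to make the broadcast explicit. It assumes PyTorch is installed and uses random stand-in tensors with the shapes from the issue (batch 128, 4 branches); only the variable names come from the repository line quoted above.

```python
import warnings

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Shapes from the issue: a batch of 128 transitions and 4 action branches.
batch_size, num_branches = 128, 4

# Random stand-ins for the real tensors (hypothetical values, real shapes).
expected_q_vals = torch.randn(batch_size, 1)              # shared target y, [128, 1]
current_q_values = torch.randn(batch_size, num_branches)  # Q_d(s, a_d), [128, 4]

# Original call: the shapes differ, so mse_loss warns and broadcasts to [128, 4].
with warnings.catch_warnings(record=True) as caught_original:
    warnings.simplefilter("always")
    loss_original = F.mse_loss(expected_q_vals, current_q_values)

# Explicit version: copy the shared target once per branch before the loss.
with warnings.catch_warnings(record=True) as caught_explicit:
    warnings.simplefilter("always")
    loss_explicit = F.mse_loss(
        current_q_values, expected_q_vals.expand_as(current_q_values)
    )

# Check whether the size-mismatch warning was raised in each case.
original_warned = any("target size" in str(w.message) for w in caught_original)
explicit_warned = any("target size" in str(w.message) for w in caught_explicit)

print(original_warned, explicit_warned)
print(torch.allclose(loss_original, loss_explicit))
```

Since squared error is symmetric, putting the arguments in the conventional (prediction, target) order does not change the value; `expand_as` only makes the intended broadcast explicit and silences the warning.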

Yaxin1996Z • Dec 18 '21 02:12

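I think the loss function is the same after broadcasting in `mse_loss`. With the mean target of equation 6, every branch shares one target, so y_d = y for all d. Broadcasting the [128, 1] target to [128, 4] and averaging all 512 squared errors gives the batch mean of (1/N) sum_d (y - Q_d(s, a_d))^2, which is equation 7.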
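The equivalence can be checked numerically without PyTorch. This sketch assumes the shared mean target of equation 6 (so y_d = y); the sizes and random values are hypothetical stand-ins, and `broadcast_mse` mimics what `mse_loss` computes after broadcasting, namely the mean over all B * N squared errors.

```python
import math
import random

# Hypothetical sizes: B transitions, N action branches (the issue uses 128 and 4).
B, N = 128, 4
random.seed(0)

# y[b]: shared target for transition b (mean-over-branches target, eq. 6).
y = [random.gauss(0.0, 1.0) for _ in range(B)]
# q[b][d]: Q_d(s_b, a_{b,d}) for each branch d.
q = [[random.gauss(0.0, 1.0) for _ in range(N)] for _ in range(B)]

# Broadcast MSE: y[b] is reused for every branch, then all B * N errors are averaged.
broadcast_mse = sum((y[b] - q[b][d]) ** 2 for b in range(B) for d in range(N)) / (B * N)

# Eq. 7 with y_d = y: per-transition branch average, then the batch mean.
eq7_loss = sum(sum((y[b] - q[b][d]) ** 2 for d in range(N)) / N for b in range(B)) / B

matches = math.isclose(broadcast_mse, eq7_loss, rel_tol=1e-12)
print(matches)
```

The argument only holds because the target is shared across branches; if each branch had its own target (a true per-branch y_d), the target tensor would need shape [B, N] and a [B, 1] broadcast would be wrong.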

allenzren • Apr 22 '22 23:04