The loss function calculation is inconsistent with the paper
loss = F.mse_loss(expected_q_vals, current_q_values)
In this line of code, expected_q_vals is the mean of the Q values over all branches (its size is [128, 1]), which corresponds to y in equation 6 of the paper.
However, the loss function in equation 7 uses y_d, which denotes the Q values of each branch.
Moreover, the size of expected_q_vals is [128, 1], but the size of current_q_values is [128, 4]. This causes a warning:
UserWarning: Using a target size (torch.Size([128, 4])) that is different to the input size (torch.Size([128, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
I think the loss function is the same after broadcasting in mse_loss.
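A minimal sketch to check the broadcasting behaviour, assuming the shapes from the warning above. The tensors are random stand-ins, not the repository's actual data, and the variable names loss_implicit, loss_explicit and loss_manual are only for illustration:

```python
import warnings

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_branches = 128, 4  # sizes taken from the warning message

# Hypothetical stand-ins for the two tensors in the issue.
expected_q_vals = torch.randn(batch_size, 1)               # one target per sample
current_q_values = torch.randn(batch_size, num_branches)   # one Q value per branch

# Implicit broadcasting: this is the call that triggers the UserWarning.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    loss_implicit = F.mse_loss(expected_q_vals, current_q_values)

# Explicit broadcasting: expand the target to [128, 4] first, so no warning.
loss_explicit = F.mse_loss(
    expected_q_vals.expand_as(current_q_values), current_q_values
)

# Written out by hand: mean squared error over all samples and branches.
loss_manual = ((expected_q_vals - current_q_values) ** 2).mean()

print(loss_implicit.item(), loss_explicit.item(), loss_manual.item())
```

If the three values agree, the warning can be silenced by expanding expected_q_vals explicitly. This only shows that the same target is applied to every branch; it does not show that this is what equation 7 intends with y_d.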