IQN-and-Extensions
IQN-DQN.ipynb max over taus instead of max over actions?
Given that forward() returns a tuple:
```python
return out.view(batch_size, num_tau, self.num_actions), taus
```
Should we use .max(1) instead of .max(2)?
Currently it is:
```python
Q_targets_next = Q_targets_next.detach().max(2)[0].unsqueeze(1)  # (batch_size, 1, N)
```
Maybe it should be:
```python
Q_targets_next = Q_targets_next.detach().max(1)[0].unsqueeze(1)  # (batch_size, 1, numActions)
```
In other words, should we take the maximum over the taus for each action, rather than the maximum over the actions for each tau? Sorry if I misunderstood the process.
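To make the two options concrete, here is a small PyTorch sketch that uses random numbers in place of network outputs (the sizes 2, 8 and 4 are arbitrary assumptions). It compares the current `.max(2)`, the proposed `.max(1)`, and, for reference, the greedy-action target from the IQN paper (Dabney et al., 2018), where a single action a* = argmax_a mean_tau Z_tau(x', a) is chosen per next state and all of its quantiles are kept. As a simplification, the sketch reuses the same taus both for choosing a* and for the target quantiles, whereas the paper samples them separately (K vs. N').

```python
import torch

# Assumed shapes, matching forward(): (batch_size, num_tau, num_actions)
batch_size, num_tau, num_actions = 2, 8, 4
torch.manual_seed(0)
q = torch.randn(batch_size, num_tau, num_actions)

# Current code: .max(2) reduces over actions, independently for every tau,
# so the winning action may differ from one tau to the next.
per_tau_max = q.max(2)[0].unsqueeze(1)           # (batch_size, 1, num_tau)

# Proposed code: .max(1) reduces over taus, independently for every action,
# so the result no longer has a tau axis at all.
per_action_max = q.max(1)[0].unsqueeze(1)        # (batch_size, 1, num_actions)

# IQN-paper style (simplified): average the quantiles over tau, pick one
# greedy action per next state, then keep every quantile of that action.
greedy_action = q.mean(1).argmax(1)              # (batch_size,)
index = greedy_action.view(batch_size, 1, 1).expand(batch_size, num_tau, 1)
paper_target = q.gather(2, index).transpose(1, 2)  # (batch_size, 1, num_tau)

print(per_tau_max.shape, per_action_max.shape, paper_target.shape)
# torch.Size([2, 1, 8]) torch.Size([2, 1, 4]) torch.Size([2, 1, 8])
```

With `.max(1)` the target would broadcast against Q_expected (batch_size, N, 1) to (batch_size, N, num_actions), so the shape assert in learn() would only pass by coincidence when num_actions equals N. `.max(2)` keeps N target quantiles but may take each one from a different action, which is why its values are always at least as large as those of the single-action target.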
```python
def learn(self, experiences):
    """Update value parameters using the given batch of experience tuples.

    Params
    ======
        experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
    The discount factor is read from self.GAMMA (raised to self.n_step).
    """
    states, actions, rewards, next_states, dones = experiences
    # Get max predicted Q values (for next states) from target model
    Q_targets_next, _ = self.qnetwork_target(next_states)
    Q_targets_next = Q_targets_next.detach().max(2)[0].unsqueeze(1)  # (batch_size, 1, N)  <-- HERE
    # Compute Q targets for current states
    Q_targets = rewards.unsqueeze(-1) + (self.GAMMA**self.n_step * Q_targets_next * (1. - dones.unsqueeze(-1)))
    # Get expected Q values from local model; N = 8 tau samples
    Q_expected, taus = self.qnetwork_local(states)
    Q_expected = Q_expected.gather(2, actions.unsqueeze(-1).expand(self.BATCH_SIZE, 8, 1))  # (batch_size, N, 1)
    # Quantile Huber loss: (batch_size, 1, N) - (batch_size, N, 1) -> (batch_size, N, N)
    td_error = Q_targets - Q_expected
    assert td_error.shape == (self.BATCH_SIZE, 8, 8), "wrong td error shape"
    huber_l = calculate_huber_loss(td_error, 1.0)
    quantil_l = abs(taus - (td_error.detach() < 0).float()) * huber_l / 1.0
    loss = quantil_l.sum(dim=1).mean(dim=1)  # add keepdim=True if PER weights get multiplied
    loss = loss.mean()
    # Minimize the loss (assumes the agent keeps its optimizer in self.optimizer;
    # these three lines are missing from the snippet as quoted)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    # ------------------- update target network ------------------- #
    self.soft_update(self.qnetwork_local, self.qnetwork_target)
    return loss.detach().cpu().numpy()
```
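`calculate_huber_loss` is called in learn() but not defined in the thread. Below is a minimal sketch, assuming it is the standard elementwise Huber loss with threshold kappa (the 1.0 passed in learn()), followed by the quantile Huber loss computed on random tensors shaped like Q_targets (batch_size, 1, N) and Q_expected (batch_size, N, 1). The helper body and the sizes are assumptions, not the repository's code.

```python
import torch

def calculate_huber_loss(td_errors: torch.Tensor, k: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: elementwise Huber loss, quadratic for |u| <= k, linear beyond."""
    abs_err = td_errors.abs()
    return torch.where(abs_err <= k, 0.5 * td_errors.pow(2), k * (abs_err - 0.5 * k))

batch_size, n = 4, 8  # assumed sizes
torch.manual_seed(0)
q_targets = torch.randn(batch_size, 1, n)    # one target quantile per tau'_j
q_expected = torch.randn(batch_size, n, 1)   # one predicted quantile per tau_i
taus = torch.rand(batch_size, n, 1)          # the tau_i used for q_expected

# Pairwise errors: td_error[b, i, j] = q_targets[b, 0, j] - q_expected[b, i, 0]
td_error = q_targets - q_expected            # (batch_size, n, n)
huber_l = calculate_huber_loss(td_error, 1.0)
# Asymmetric weight |tau_i - 1{delta < 0}|, broadcast along the target axis j
quantile_l = (taus - (td_error.detach() < 0).float()).abs() * huber_l / 1.0
# Sum over predicted quantiles i (dim=1), average over target quantiles j, then over the batch
loss = quantile_l.sum(dim=1).mean(dim=1).mean()
```

Summing over dim=1 sums over the predicted quantiles tau_i, and the mean over the remaining axis averages over the target quantiles tau'_j, which matches the (1/N') Σ_i Σ_j ρ form of the loss in the IQN paper.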
Hey, I haven't had a look at the code for a while. Can you reference the part in the paper that made you think it's the max over the taus?
Hi Sebastian,
From page 5 of this paper. These equations are a bit tough for me, but looking at equations 2 and 3 from here: