Some clarification questions about carry
First of all, thank you for releasing the source code! This is a very cool paper, and I am genuinely surprised that it works this well. While reading the source code I got a bit confused about the role of `carry` (probably missing something obvious), hence this post.
It seems like `carry` is only initialized once here, https://github.com/sapientinc/HRM/blob/4047578a02e5deba975c38a1f32392547e66c071/pretrain.py#L220 and is reset only for halted sequences https://github.com/sapientinc/HRM/blob/4047578a02e5deba975c38a1f32392547e66c071/models/hrm/hrm_act_v1.py#L242 https://github.com/sapientinc/HRM/blob/4047578a02e5deba975c38a1f32392547e66c071/models/hrm/hrm_act_v1.py#L174-L179
which happens after at most 16 gradient steps with the default config https://github.com/sapientinc/HRM/blob/4047578a02e5deba975c38a1f32392547e66c071/config/arch/hrm_v1.yaml#L7
Does that imply we are passing the hidden state across different examples? I.e., something like:
```python
z = z_init
for e in range(epochs):
    for x, y_true in train_dataloader:
        z, y_hat = hrm(z, x)                      # hidden state carried across batches
        loss = softmax_cross_entropy(y_hat, y_true)
        z = z.detach()                            # stop gradients from crossing batches
        loss.backward()
        opt.step()
        opt.zero_grad()
```
instead of what is described in the paper (Fig. 4):
```python
for x, y_true in train_dataloader:
    z = z_init                                    # fresh state for every example
    for step in range(N_supervision):
        z, y_hat = hrm(z, x)
        loss = softmax_cross_entropy(y_hat, y_true)
        z = z.detach()                            # detach between supervision segments
        loss.backward()
        opt.step()
        opt.zero_grad()
```
Also, a follow-up question: the default config uses `H_cycles: 2` and `L_cycles: 2`
https://github.com/sapientinc/HRM/blob/4047578a02e5deba975c38a1f32392547e66c071/config/arch/hrm_v1.yaml#L9-L10
which is probably insufficient for $z_H$ to be close to $z_H^*$ (the fixed point of $f_H$). Wouldn't the 1-step gradient approximation be very poor then?
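For reference, my understanding of the approximation (please correct me if I'm wrong): at a fixed point $z_H^* = f_H(z_H^*; \theta)$, the implicit function theorem gives the exact gradient, and the 1-step approximation drops the inverse-Jacobian term, i.e. it backpropagates only through the last application of $f_H$:

$$
\frac{\partial z_H^*}{\partial \theta}
= \left(I - \frac{\partial f_H}{\partial z_H}\bigg|_{z_H^*}\right)^{-1}
\frac{\partial f_H}{\partial \theta}\bigg|_{z_H^*}
\;\approx\; \frac{\partial f_H}{\partial \theta}\bigg|_{z_H}
$$

So if $z_H$ after only 2 H-cycles is still far from $z_H^*$, both the evaluation point and the dropped term seem like they could introduce a lot of error.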
I was also confused by this, but I think the answer is here: https://github.com/sapientinc/HRM/blob/4047578a02e5deba975c38a1f32392547e66c071/models/hrm/hrm_act_v1.py#L246
```python
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
```
This only updates the current data to the new batch for the slots where `halted` is True; otherwise a slot keeps the previous data (that the carry was initialised on).
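A tiny standalone toy example of the same pattern (made-up tensors, not the actual model code), just to show the per-slot behaviour:

```python
import torch

# toy batch of 4 samples, each with 3 input tokens
current_data = torch.tensor([[1, 1, 1],
                             [2, 2, 2],
                             [3, 3, 3],
                             [4, 4, 4]])
new_batch = torch.tensor([[9, 9, 9],
                          [8, 8, 8],
                          [7, 7, 7],
                          [6, 6, 6]])
halted = torch.tensor([True, False, True, False])

# same broadcast trick as in hrm_act_v1.py: only halted slots pick up new data
mask = halted.view((-1,) + (1,) * (new_batch.ndim - 1))   # shape (4, 1)
updated = torch.where(mask, new_batch, current_data)
print(updated)
# rows 0 and 2 come from new_batch; rows 1 and 3 keep their old data
```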
This is a fantastic thread. Thanks, @robert1003, for asking this - I had the exact same confusion while reading the `pretrain.py` script. The "stateful" nature of the carry across batches seemed to contradict the per-example deep supervision described in the paper.
And great find, @alec-tschantz, that `torch.where` line is indeed the key. It's a really elegant, classic PyTorch solution to the problem.
Just to synthesize my understanding: the implementation effectively vectorizes the M-loop ("thinking sessions") by processing a continuous stream of data. The carry is stateful, but the `torch.where` logic acts as a manager that selectively refreshes slots in the batch. So, for any given sample, the model does keep working on the same `x` for multiple M steps if it hasn't halted, which matches the paper's description. The carry's state only "leaks" in the sense that it is the same tensor; the data it operates on is managed on a per-sample basis.
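In rough pseudocode (hypothetical names, just to make the mental model concrete - not the actual repo code), I picture the training stream like this:

```python
# rough mental model of the vectorized M-loop (names are hypothetical)
carry = model.initial_carry(batch_size)              # created once, before training
for batch in train_dataloader:
    # refresh only the slots whose sample halted in the previous segment;
    # non-halted slots keep their old x and keep refining their hidden state
    carry.current_data = {
        k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)),
                       batch[k], v)
        for k, v in carry.current_data.items()
    }
    carry, y_hat = model(carry)                       # one supervision segment (N*T steps)
    loss = softmax_cross_entropy(y_hat, carry.current_data["labels"])
    loss.backward()
    opt.step()
    opt.zero_grad()
```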
This implementation detail also directly answers your follow-up question, @robert1003. The effective computational depth for a single, difficult example isn't just N*T, but can be up to M_max * N * T. With the default configs, that's 16 * 2 * 2 = 64 total recurrent steps. This longer effective horizon makes the one-step gradient approximation (which assumes proximity to a fixed point) much more plausible than if the model were only running for 4 steps.
Looks like a clever & efficient way to implement the deep supervision logic. Thanks again to both of you for clarifying this!
@narvind2003 I'm not sure I agree with you on this paragraph:
> This implementation detail also directly answers your follow-up question, @robert1003. The effective computational depth for a single, difficult example isn't just N*T, but can be up to M_max * N * T. With the default configs, that's 16 * 2 * 2 = 64 total recurrent steps. This longer effective horizon makes the one-step gradient approximation (which assumes proximity to a fixed point) much more plausible than if the model were only running for 4 steps.
It does run up to 64 total recurrent steps; however, a gradient update is performed every N*T = 4 steps, so no single backward pass ever sees the full 64-step horizon.
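Concretely, my reading of one supervision segment is roughly the following (hypothetical helper names like `hrm_step`; the real code fuses the H/L cycles), so the backward pass never spans more than the final update:

```python
# rough sketch of ONE supervision segment (hypothetical helper names)
with torch.no_grad():
    # all but the final update run without building a graph (the 1-step gradient trick)
    for _ in range(N * T - 1):
        z_L, z_H = hrm_step(z_L, z_H, x)

z_L, z_H = hrm_step(z_L, z_H, x)          # only this step is differentiated
loss = softmax_cross_entropy(output_head(z_H), y_true)
loss.backward()                           # gradient spans a single step, not 64
opt.step()
opt.zero_grad()

z_L, z_H = z_L.detach(), z_H.detach()     # segment boundary: no gradient across segments
```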