Stateful training doesn't seem to work

Open • bmilde opened this issue 5 months ago • 1 comment

I'm trying to use xLSTM for stateful training, i.e. propagating a (detached) state between consecutive segments. To keep things simple, I'm using the native PyTorch backend (no Triton kernels). The following code works with a traditional LSTM (nn.LSTM) but not with xLSTM:

    while True:
[...]
            with autocast(device_type=device_str, dtype=torch.float16):
                if input_state:
                    input_state = detach_states(input_state)
                    #input_state = copy.deepcopy(input_state)
                    if args.debug:
                        assert_all_detached(input_state)
                enc_out, output_state = model(feats, input_state)
                logp = enc_out.log_softmax(-1).transpose(0, 1)
                loss = criterion(logp, tokens, in_lens, tgt_lens)
                loss = loss / args.accumulation_steps 
[...]
                input_state = output_state
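For comparison, here is a minimal, self-contained sketch of the same stateful (truncated-BPTT) pattern with nn.LSTM, the case that works. The toy sizes, random data, linear head, MSE loss and SGD optimizer are all placeholders and not the setup from this report; only the detach-then-feed-back step mirrors the loop above.

```python
import torch
import torch.nn as nn

# Hypothetical toy setup: random tensors stand in for real features/targets.
torch.manual_seed(0)
model = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)  # placeholder output layer
opt = torch.optim.SGD(list(model.parameters()) + list(head.parameters()), lr=0.1)

state = None  # None lets nn.LSTM initialize (h0, c0) with zeros
losses = []
for step in range(5):
    x = torch.randn(2, 10, 4)  # one segment: batch=2, seq_len=10, features=4
    y = torch.randn(2, 10, 1)
    if state is not None:
        # Cut the graph at the segment boundary: keep the values,
        # drop the history that produced them.
        state = tuple(s.detach() for s in state)
    out, state = model(x, state)  # state is (h_n, c_n)
    loss = nn.functional.mse_loss(head(out), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())
```

Without the detach, the second iteration would try to backpropagate into the previous segment's already-freed graph.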

My code also works with xLSTM when I pass None as the state (i.e. a fresh state is initialized every time). When I instead pass in the detached xLSTM state from the previous optimizer update, I get the following error (with torch.autograd.set_detect_anomaly(True) enabled):

/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/autograd/graph.py:824: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "/scratch/statecatcher/train.py", line 495, in <module>
    train(args)
  File "/scratch/statecatcher/train.py", line 393, in train
    enc_out, output_state = model(feats, input_state)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/statecatcher/model.py", line 128, in forward
    logits, new_states = self.encoder(feats, states)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/xlstm/xlstm_large/model.py", line 150, in forward
    x, state = self.backbone(x, state)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/xlstm/xlstm_large/model.py", line 221, in forward
    x, block_state_new = block(x, block_state)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/xlstm/xlstm_large/model.py", line 507, in forward
    x_mlstm, state = self.mlstm_layer(x_mlstm, state)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/xlstm/xlstm_large/model.py", line 429, in forward
    h, state = self.mlstm_backend(
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/mlstm_kernels/torch/backend_module.py", line 179, in forward
    return self._train_fn(
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/mlstm_kernels/torch/chunkwise/native/fwbw.py", line 208, in mlstm_chunkwise__native_autograd
    matH_out, _, _, last_states, _ = mlstm_chunkwise_fw(
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/mlstm_kernels/torch/chunkwise/native/fw.py", line 268, in mlstm_chunkwise_fw
    matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise__recurrent_fw_C(
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/mlstm_kernels/torch/chunkwise/native/fw.py", line 116, in mlstm_chunkwise__recurrent_fw_C
    vecN_k_next = scaGbar_k * vecN_k + matK_chunk_gated.transpose(-2, -1).sum(-1)
 (Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/scratch/statecatcher/train.py", line 495, in <module>
    train(args)
  File "/scratch/statecatcher/train.py", line 398, in train
    scaler.scale(loss).backward()
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/autograd/__init__.py", line 353, in backward
    _engine_run_backward(
  File "/scratch/statecatcher/venv/lib/python3.12/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [8, 2, 20]], which is output 0 of torch::autograd::CopyBackwards, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

This suggests that something inside the model itself modifies a tensor in place when detached states are passed in to initialize the xLSTM state.

I'm using the following code to detach the states, and I have also asserted that all tensors in the nested xLSTM state structure (a dict of tuples of tensors) are detached:

import torch

def assert_all_detached(x):
    """Recursively assert that no tensor in a nested state structure requires grad."""
    if isinstance(x, torch.Tensor):
        assert not x.requires_grad, "Tensor still requires grad"
    elif isinstance(x, (list, tuple)):
        for v in x:
            assert_all_detached(v)
    elif isinstance(x, dict):
        for v in x.values():
            assert_all_detached(v)

def detach_states(states):
    """Recursively detach all tensors in nested state structures (dicts, tuples, lists)."""
    print(states)  # debug: dump the incoming state structure
    if states is None:
        return None
    elif isinstance(states, torch.Tensor):
        return states.detach()
    elif isinstance(states, dict):
        return {k: detach_states(v) for k, v in states.items()}
    elif isinstance(states, tuple):
        return tuple(detach_states(s) for s in states)
    elif isinstance(states, list):
        return [detach_states(s) for s in states]
    else:
        # Catch any unexpected data type (e.g., numbers, strings)
        return states
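One detail worth keeping in mind here: per the PyTorch documentation, `Tensor.detach()` returns a tensor that shares storage with the original, and in-place modifications on either are seen by both (and can trip autograd's correctness checks). It removes the graph but does not copy the data. The sketch below demonstrates this with a stand-in tensor; `.detach().clone()` gives an independent copy. Whether cloning on the caller side avoids the error in this issue is not established in this thread; the workaround reported in the comment patches the kernel instead.

```python
import torch

# Stand-in for one state tensor produced by the previous segment.
prev = torch.randn(2, 3, requires_grad=True) * 1.0  # non-leaf with a grad_fn

detached = prev.detach()               # no grad_fn, but SAME storage as `prev`
independent = prev.detach().clone()    # no grad_fn and its own storage

shares_storage = detached.data_ptr() == prev.data_ptr()
owns_storage = independent.data_ptr() != prev.data_ptr()

# An in-place write through the detached tensor is visible in `prev` too.
detached.zero_()
print(shares_storage, owns_storage, bool((prev == 0).all()))
```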

The recursive detaching is necessary because the state object is a nested structure. Summarized, it looks like this:


{
    0: (
        tensor([[[[...], [...], ...]]], device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>),
        tensor([[[-0., -0., ...], ...]], device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>),
        tensor([[[0.1024], [-0.0309]], ...], device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>)
    ),
    1: (
        tensor([[[[...], [...], ...]]], device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>),
        tensor([[[-0., -0., ...], ...]], device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>),
        tensor([[[0.1024], [-0.0309]], ...], device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>)
    ),
    2: (
        tensor([[[[...], [...], ...]]], device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>),
        tensor([[[-0., -0., ...], ...]], device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>),
        tensor([[[0.1024], [-0.0309]], ...], device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>)
    )
}

bmilde, Jul 25 '25 12:07

I got it running without errors by adding .clone() to:


mlstm_kernels/torch/chunkwise/native/fw.py, line 116, in mlstm_chunkwise__recurrent_fw_C
vecN_k_next = scaGbar_k * vecN_k + matK_chunk_gated.transpose(-2, -1).sum(-1)

=>

vecN_k_next = scaGbar_k * vecN_k.clone() + matK_chunk_gated.transpose(-2, -1).sum(-1)

and in the same file, line 111:

        # NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
        matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(
            -2, -1
        ) @ (matV_chunk)

=>

        # NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
        matC_k_next = scaGbar_k[..., None] * matC_k.clone() + matK_chunk_gated.transpose(
            -2, -1
        ) @ (matV_chunk)

Not sure if this breaks something else though.
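A minimal sketch of why the `.clone()` workaround helps, using the same hypothetical toy pattern as a reproduction of this error (again not the mlstm_kernels code): the multiplication now saves the clone, so a later in-place update of the original tensor no longer touches anything autograd needs. Since `clone()` is differentiable with an identity gradient, the gradient in this toy case is the same as if no in-place update had happened.

```python
import torch

# Mirrors the `scaGbar_k * vecN_k.clone()` idea on a toy example.
w = torch.ones(3, requires_grad=True)
state = torch.zeros(3)
state.copy_(w * 2)       # state = 2 * w, tracked via CopyBackwards
out = w * state.clone()  # MulBackward0 saves the clone, not `state`
state.add_(1.0)          # in-place update no longer invalidates a saved tensor

out.sum().backward()     # succeeds
# out = w * 2w = 2 w^2, so d(sum)/dw = 4 w; gradients flow through clone()
print(w.grad)
```

The cost of this approach is one extra copy of the state tensor per multiplication in the forward pass.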

bmilde, Jul 25 '25 12:07