Kernel optimisations for different objectives
Hi! I'm currently implementing LoRA in order to train this model on lower-end consumer-grade GPUs. For the linear projections there are no issues and the stability/memory usage seems OK; the problem is in the selective scan kernels as well as in the layer norm operations, which give massive spikes in memory. It would be great if you could give me some advice on where to modify the code to trade off some speed for reduced memory consumption.
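For context, a minimal sketch of the kind of LoRA wrapper I mean (my actual code differs, and the `in_proj`/`out_proj` module paths below are just illustrative):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base linear plus a trainable low-rank update: y = Wx + (alpha/r) * B(A(x))."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)          # freeze the pretrained weights
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)   # update starts at zero, so behavior is unchanged at init
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_B(self.lora_A(x)) * self.scaling

# Applying it to a block's projections (module paths depend on the checkpoint):
# block.mixer.in_proj  = LoRALinear(block.mixer.in_proj)
# block.mixer.out_proj = LoRALinear(block.mixer.out_proj)
```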
I don't get why there would be a spike in memory. Are you saying it happens on consumer-grade GPUs and not on data-center GPUs?
I don't think we allocate a lot of extra memory anywhere.
I haven't tried on data-center GPUs yet. With my implementation I managed to bring memory usage down to about 8-9 GB at a 4096-token context window, but the profiler shows that both the selective scan forward pass and the layer norm produce huge spikes which, with long sequences, lead to OOM.
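For reference, roughly how I'm profiling it (a sketch; `model` and `x` are placeholders for the actual model and a long input batch):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Record per-op CUDA memory so the selective scan and layer norm
# allocations show up in the summary table.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
) as prof:
    loss = model(x).sum()   # forward at the long context length
    loss.backward()         # backward, where the spikes appear

print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))
```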
For layer norm specifically, what if you use torch.nn.LayerNorm? If the memory blows up there too, then it's a general problem and not specific to the Mamba model.
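Something like this as a quick test (the `backbone.layers`/`norm` paths are an assumption about your setup, and if your blocks use the fused RMSNorm, LayerNorm isn't numerically identical, so treat this purely as a memory diagnostic):

```python
import torch.nn as nn

# Swap each block's fused norm for plain torch.nn.LayerNorm to check
# whether the spike comes from the fused kernel or from normalization
# in general. Module layout is an assumption.
for block in model.backbone.layers:
    w = block.norm.weight
    ln = nn.LayerNorm(w.shape[0]).to(device=w.device, dtype=w.dtype)
    ln.weight.data.copy_(w.data)   # reuse the learned scale; RMSNorm has no bias
    block.norm = ln
```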
I will try. I tried using the unsloth implementation, but the result was almost the same. What about the selective scan?
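One thing I'm considering for the scan is gradient checkpointing around each block, which trades speed for memory. A sketch, with the module path again an assumption:

```python
from torch.utils.checkpoint import checkpoint

def checkpointed(module):
    """Recompute the module's activations during backward instead of
    storing them: slower, but a much flatter peak memory profile."""
    original_forward = module.forward
    def forward(*args, **kwargs):
        return checkpoint(original_forward, *args, use_reentrant=False, **kwargs)
    module.forward = forward
    return module

# `model.backbone.layers` is an assumption about the model layout.
for block in model.backbone.layers:
    checkpointed(block)
```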
I've no idea why there would be a memory spike; if you figure it out, let me know.
Ok cool!