
Kernel optimisations for different objectives

Open mlinmg opened this issue 1 year ago • 6 comments

Hi! I'm currently implementing LoRA in order to be able to train this model on lower-end consumer-grade GPUs. For the linear projections there are no issues and the stability/memory usage seems fine; the problem is in the selective scan kernels as well as in the layer norm operations, which produce massive spikes in memory. It would be great if you could give me some advice on where to modify the code so I can trade off some speed for reduced memory consumption.
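
For context, the LoRA wrapper I'm using on the linear projections is roughly this (a minimal sketch; the rank/alpha values are just my defaults, and which projections I actually patch, e.g. `in_proj`/`out_proj`, may vary):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Freeze a base nn.Linear and add a trainable low-rank update: W x + (alpha/r) * B A x."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False
        # low-rank factors: A is (r, in_features), B is (out_features, r), B starts at zero
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x):
        # frozen base output plus the scaled low-rank correction
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
```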

mlinmg avatar Jan 20 '24 10:01 mlinmg

I don't get why there would be a spike in memory. Are you saying it happens on consumer grade GPUs and not on data center GPUs?

I don't think we allocate a lot of extra memory anywhere.

tridao avatar Jan 24 '24 04:01 tridao

I haven't tried on datacenter GPUs yet. With my implementation I managed to bring memory usage down to about 8-9 GB at a 4096-token context window, but running the profiler shows that both the selective scan forward and the layer norm produce huge spikes which, with long sequences, lead to OOM.
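
This is roughly how I'm capturing the spikes (a sketch; `model` and `input_ids` stand in for my actual training setup):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# record CUDA memory per op during one forward/backward pass
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             profile_memory=True) as prof:
    out = model(input_ids)
    loss = out.logits.float().mean()  # dummy objective just to trigger backward
    loss.backward()

print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))
```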

mlinmg avatar Jan 24 '24 08:01 mlinmg

For layer norm specifically, what if you use torch.nn.LayerNorm? If the memory blows up there too, then it's a general problem and not specific to the Mamba model.
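
Something along these lines (untested sketch, attribute names from memory; check against your model):

```python
import torch.nn as nn

# Swap each block's (RMS)norm for a plain torch.nn.LayerNorm to see whether the
# memory spike is specific to the fused kernel. The block also has to take the
# non-fused path, otherwise the Triton kernel is still called.
for block in model.backbone.layers:
    w = block.norm.weight
    block.norm = nn.LayerNorm(w.shape[0], eps=1e-5).to(device=w.device, dtype=w.dtype)
    block.fused_add_norm = False
```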

tridao avatar Jan 24 '24 08:01 tridao

I will try. I tried the Unsloth implementation but got almost the same result. What about the selective scan?

mlinmg avatar Jan 24 '24 08:01 mlinmg

I've no idea why there would be a memory spike; if you figure it out, let me know.

tridao avatar Jan 24 '24 09:01 tridao

Ok cool!

mlinmg avatar Jan 24 '24 09:01 mlinmg