No backward pass for `RmsNorm` if tensor is contiguous

Status: Open · opened by agerasev 1 year ago · 0 comments

`RmsNorm` switches to a faster implementation when the input tensor is contiguous:

https://github.com/huggingface/candle/blob/82b641fd2752e3b14db6a9c91faef70e3329f3b5/candle-nn/src/layer_norm.rs#L174-L175
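To make the contiguity condition concrete, here is a rough Rust sketch (not candle's code) of why a fast kernel likes a packed layout. With row-major strides each row is one slice it can walk directly, while a strided view (e.g. a transpose) needs per-element index arithmetic. `View`, `rms_norm_contiguous`, `rms_norm_strided` and `rms_norm` are hypothetical names; both paths compute `y_i = alpha_i * x_i / sqrt(mean(x^2) + eps)` per row.

```rust
/// Hypothetical strided 2-D view over an f32 buffer (stand-in for a tensor layout).
pub struct View<'a> {
    pub data: &'a [f32],
    pub shape: [usize; 2],   // [rows, cols]
    pub strides: [usize; 2], // element strides per dimension
}

impl<'a> View<'a> {
    /// "Contiguous" here means packed row-major: last stride 1, rows back to back.
    pub fn is_contiguous(&self) -> bool {
        self.strides[1] == 1 && self.strides[0] == self.shape[1]
    }

    fn get(&self, r: usize, c: usize) -> f32 {
        self.data[r * self.strides[0] + c * self.strides[1]]
    }
}

/// Fast path: assumes a packed layout and processes each row as a plain slice.
fn rms_norm_contiguous(x: &View, alpha: &[f32], eps: f32) -> Vec<f32> {
    let [rows, cols] = x.shape;
    let mut out = Vec::with_capacity(rows * cols);
    for row in x.data[..rows * cols].chunks_exact(cols) {
        let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / cols as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        out.extend(row.iter().zip(alpha).map(|(v, a)| v * inv_rms * a));
    }
    out
}

/// General path: works for any strides by computing each element's offset.
fn rms_norm_strided(x: &View, alpha: &[f32], eps: f32) -> Vec<f32> {
    let [rows, cols] = x.shape;
    let mut out = Vec::with_capacity(rows * cols);
    for r in 0..rows {
        let mean_sq = (0..cols).map(|c| x.get(r, c).powi(2)).sum::<f32>() / cols as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        out.extend((0..cols).map(|c| x.get(r, c) * inv_rms * alpha[c]));
    }
    out
}

/// Dispatch on layout, analogous in spirit to the check linked above.
pub fn rms_norm(x: &View, alpha: &[f32], eps: f32) -> Vec<f32> {
    if x.is_contiguous() {
        rms_norm_contiguous(x, alpha, eps)
    } else {
        rms_norm_strided(x, alpha, eps)
    }
}
```

Both paths return the same values; the layout only decides how cheaply the kernel can read them.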

But the fast implementation does not support a backward pass:

https://github.com/huggingface/candle/blob/82b641fd2752e3b14db6a9c91faef70e3329f3b5/candle-nn/src/ops.rs#L640
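For reference, this is the gradient a backward pass for one row of RMSNorm would have to produce, written as a plain Rust sketch over `f32` slices rather than against candle's custom-op API (`rms_norm_row` and `rms_norm_row_backward` are hypothetical names). With `r = 1 / sqrt(mean(x^2) + eps)` and `y_i = alpha_i * x_i * r`, the input gradient has a direct term plus a term flowing through the shared `r`.

```rust
/// Forward for one row: y_i = alpha_i * x_i * r, with r = 1 / sqrt(mean(x^2) + eps).
pub fn rms_norm_row(x: &[f32], alpha: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let r = 1.0 / (x.iter().map(|v| v * v).sum::<f32>() / n + eps).sqrt();
    x.iter().zip(alpha).map(|(xi, ai)| ai * xi * r).collect()
}

/// Backward for one row, given the upstream gradient dy = dL/dy.
/// Returns (dL/dx, dL/dalpha) where
///   dL/dx_i     = r * alpha_i * dy_i - (r^3 * x_i / n) * sum_j(alpha_j * x_j * dy_j)
///   dL/dalpha_i = r * x_i * dy_i   (for a batch, sum this over all rows)
pub fn rms_norm_row_backward(
    x: &[f32],
    alpha: &[f32],
    eps: f32,
    dy: &[f32],
) -> (Vec<f32>, Vec<f32>) {
    let n = x.len() as f32;
    let r = 1.0 / (x.iter().map(|v| v * v).sum::<f32>() / n + eps).sqrt();
    // Reduction shared by every dx_i: sum_j alpha_j * x_j * dy_j.
    let dot: f32 = x
        .iter()
        .zip(alpha)
        .zip(dy)
        .map(|((xi, ai), gi)| ai * xi * gi)
        .sum();
    let dx: Vec<f32> = x
        .iter()
        .zip(alpha)
        .zip(dy)
        .map(|((xi, ai), gi)| r * ai * gi - r.powi(3) * xi / n * dot)
        .collect();
    let dalpha: Vec<f32> = x.iter().zip(dy).map(|(xi, gi)| r * xi * gi).collect();
    (dx, dalpha)
}
```

A central finite-difference check against `rms_norm_row` is a cheap way to validate a kernel like this before wiring it into an op's backward.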

Maybe it would be better to implement `ModuleT` rather than `Module` for `RmsNorm`, and use the faster implementation only when `train == false`?
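A minimal Rust sketch of that proposal, under simplifying assumptions: rows are plain `&[f32]` slices, `is_contiguous` is passed in as a flag standing in for the tensor's own check, and `ForwardT`, `RmsPath` and `select_path` are hypothetical names. The trait only mirrors the shape of a `forward_t(xs, train)` call; it is not candle's `ModuleT`. The fast path is taken only for contiguous input in inference mode, and everything else goes through the step-by-step path that autograd could presumably differentiate.

```rust
/// Which implementation a forward call ends up using.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RmsPath {
    /// Fast single-pass kernel: per the issue, no backward pass.
    Fused,
    /// Composition of basic ops: slower, but each step is differentiable.
    Composed,
}

/// Proposed rule: take the fast kernel only when not training.
pub fn select_path(is_contiguous: bool, train: bool) -> RmsPath {
    if is_contiguous && !train {
        RmsPath::Fused
    } else {
        RmsPath::Composed
    }
}

/// Hypothetical train-aware forward trait (stand-in, not candle's API).
pub trait ForwardT {
    fn forward_t(&self, row: &[f32], is_contiguous: bool, train: bool) -> (Vec<f32>, RmsPath);
}

pub struct RmsNorm {
    pub weight: Vec<f32>,
    pub eps: f32,
}

impl RmsNorm {
    /// Stand-in for the fast kernel: one pass for 1/rms, one pass to scale.
    fn fused(&self, row: &[f32]) -> Vec<f32> {
        let inv_rms =
            1.0 / (row.iter().map(|v| v * v).sum::<f32>() / row.len() as f32 + self.eps).sqrt();
        row.iter().zip(&self.weight).map(|(v, w)| v * inv_rms * w).collect()
    }

    /// Stand-in for the composed path: square, mean, add eps, sqrt, divide, scale,
    /// where each step would be a separate tensor op that autograd can track.
    fn composed(&self, row: &[f32]) -> Vec<f32> {
        let sq: Vec<f32> = row.iter().map(|v| v * v).collect();
        let mean = sq.iter().sum::<f32>() / sq.len() as f32;
        let rms = (mean + self.eps).sqrt();
        row.iter()
            .map(|v| v / rms)
            .zip(&self.weight)
            .map(|(v, w)| v * w)
            .collect()
    }
}

impl ForwardT for RmsNorm {
    fn forward_t(&self, row: &[f32], is_contiguous: bool, train: bool) -> (Vec<f32>, RmsPath) {
        let path = select_path(is_contiguous, train);
        let out = match path {
            RmsPath::Fused => self.fused(row),
            RmsPath::Composed => self.composed(row),
        };
        (out, path)
    }
}
```

The trade-off of this design is that training always pays for the slower path until the fast op gets its own backward.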

agerasev, May 07 '24 07:05