candle
Add masked_fill under Tensor
This unifies the masked_fill implementations under Tensor.
Addresses #2370.
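For context, a minimal sketch of the `masked_fill` semantics on plain slices (this is an illustration only, not candle's `Tensor`-based implementation, which operates on tensors and supports broadcasting):

```rust
// Replace values[i] with `fill` wherever mask[i] is true.
// Simplified stand-in for the tensor-level masked_fill discussed in this PR.
fn masked_fill(values: &[f32], mask: &[bool], fill: f32) -> Vec<f32> {
    values
        .iter()
        .zip(mask.iter())
        .map(|(&v, &m)| if m { fill } else { v })
        .collect()
}

fn main() {
    let values = [1.0_f32, 2.0, 3.0, 4.0];
    // A typical use: masking out attention scores with -inf before softmax.
    let mask = [false, true, false, true];
    let out = masked_fill(&values, &mask, f32::NEG_INFINITY);
    println!("{:?}", out);
}
```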
I'm not sure we want to have this under Tensor; the goal there is not to replicate all the functions in PyTorch but rather to keep a smaller subset of basic functions. Maybe put it in candle-transformers/src/utils.rs instead.
Also note that the tensor pre-allocation in some of the models was done on purpose: it avoids a sync point on CUDA, which was a performance problem, so we would want to keep the current implementation for those models.
@LaurentMazare thanks for the comments. I will update the PR accordingly.