[draft] linear_change_weight_dim_order
Dirty PR here, not meant to merge
@nathanielsimard @antimora I benchmarked linear without (Linear) and with (LinearT) the new weight order change. In general, it is noticeably slower when transposing. Local median results are shown below:
Weights of 512x512
- Tch-CPU (did not benchmark GPU because of a default device bug)
  - Linear: 59.638 ms
  - LinearT: 70.217 ms
- WGPU
  - Linear: 86.077 ms
  - LinearT: 87.300 ms
- NdArray
  - Linear: 308.077 ms
  - LinearT: 1.459 s
- Candle-CPU
  - Linear: 420.805 ms
  - LinearT: 769.235 ms

Weights of 1024x1024
- Tch-CPU
  - Linear: 182.468 ms
  - LinearT: 194.116 ms
- WGPU
  - Linear: 288.025 ms
  - LinearT: 294.151 ms

Weights of 2048x2048
- Tch-CPU
  - Linear: 852.782 ms
  - LinearT: 705.922 ms
- WGPU
  - Linear: 1.053 s
  - LinearT: 1.081 s
Interestingly, Tch becomes faster with LinearT for large matmuls, and WGPU's LinearT overhead stays small across all sizes.
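For context, here is a minimal sketch of the two variants being compared, assuming burn's `Tensor` API; the function names are illustrative, not the PR's actual code:

```rust
use burn::tensor::{backend::Backend, Tensor};

// Linear: weight stored as [d_input, d_output]; the forward pass is a plain matmul.
fn forward_linear<B: Backend>(x: Tensor<B, 2>, weight: Tensor<B, 2>) -> Tensor<B, 2> {
    x.matmul(weight)
}

// LinearT: weight stored as [d_output, d_input] (torch's layout); the forward pass
// must transpose the weight before the matmul, which is the cost the benchmark measures.
fn forward_linear_t<B: Backend>(x: Tensor<B, 2>, weight: Tensor<B, 2>) -> Tensor<B, 2> {
    x.matmul(weight.transpose())
}
```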
Thank you for running the numbers. It does seem slower. Perhaps we should use a more specialized op? For example, torch uses `torch.addmm(self.bias, x, self.weight.t())` instead of `matmul` (https://github.com/pytorch/pytorch/blob/1474eb5f293215c995b7c224ef924c305b166be1/torch/nn/modules/linear.py#L116).
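To illustrate the suggestion: burn has no `addmm` today, but in terms of existing ops, `torch.addmm(self.bias, x, self.weight.t())` computes the following (a sketch only; a real specialized op would fuse these into one kernel):

```rust
use burn::tensor::{backend::Backend, Tensor};

// Semantics of torch.addmm(bias, x, weight.t()): bias + x @ weight.T.
// weight is stored [d_output, d_input] as in torch; unsqueeze broadcasts
// the bias over the batch dimension.
fn addmm_semantics<B: Backend>(
    x: Tensor<B, 2>,
    weight: Tensor<B, 2>,
    bias: Tensor<B, 1>,
) -> Tensor<B, 2> {
    x.matmul(weight.transpose()) + bias.unsqueeze()
}
```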
Thank you, @louisfd, for your research. This helped me go another route and leave the nn modules as they are. In my PR (https://github.com/tracel-ai/burn/pull/1085), I am using a custom deserializer with an adapter, which changes tensors/names before loading.
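A rough sketch of that adapter idea, with hypothetical names (not the actual types from the PR): intercept each parameter during deserialization and transpose linear weights at load time, so the module definitions stay untouched:

```rust
use burn::tensor::{backend::Backend, Tensor};

// Hypothetical load-time hook; only 2-D parameters are shown (rank-1 biases
// would need a second method, omitted for brevity).
trait LoadAdapter<B: Backend> {
    fn adapt_2d(&self, name: &str, tensor: Tensor<B, 2>) -> Tensor<B, 2>;
}

struct PyTorchLinearAdapter;

impl<B: Backend> LoadAdapter<B> for PyTorchLinearAdapter {
    fn adapt_2d(&self, name: &str, tensor: Tensor<B, 2>) -> Tensor<B, 2> {
        // torch stores linear weights as [d_output, d_input]; burn's Linear
        // expects [d_input, d_output], so transpose once while loading.
        if name.ends_with("weight") {
            tensor.transpose()
        } else {
            tensor
        }
    }
}
```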
Glad we could work that out 😄