EkfacInfluence breaks for large LinearLayer
Due to an issue in PyTorch, EkfacInfluence does not work for linear layers with more than 2895 input or output features.
This is caused by an issue in torch.linalg.eigh which leads to a 32-bit integer overflow. The only fix is to use 64-bit integers; PyTorch is working on that AFAIK, see the issue here.
In EkfacInfluence, we are using
evals_a, evecs_a = torch.linalg.eigh(forward_x[key])
evals_g, evecs_g = torch.linalg.eigh(grad_y[key])
with forward_x and grad_y coming from _init_layer_kfac_blocks:
sG = module.out_features
sA = module.in_features + int(with_bias)
forward_x_layer = torch.zeros((sA, sA), device=module.weight.device)
grad_y_layer = torch.zeros((sG, sG), device=module.weight.device)
Basically, the matrix torch.zeros((sA, sA), ...) is just too big for torch.linalg.eigh to handle. This then raises a rather cryptic error message:
RuntimeError: false INTERNAL ASSERT FAILED at ".../aten/src/ATen/native/BatchLinearAlgebra.cpp":1462,
please report a bug to PyTorch. linalg.eigh: Argument 8 has illegal value.
Most likely there is a bug in the implementation that calls the backend library.
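For reference, a minimal standalone reproduction could look like the sketch below. The size 2896 is simply one above the 2895 limit reported above; on unaffected backends the call just succeeds, so this is only illustrative.

import torch

# Hypothetical reproduction: on affected backends (reported for MPS / certain
# MKL builds) this triggers the INTERNAL ASSERT above; elsewhere it succeeds.
n = 2896  # one above the largest size reported to work (2895)
matrix = torch.zeros((n, n))  # symmetric by construction, as eigh expects
try:
    torch.linalg.eigh(matrix)
    print("eigh succeeded; this backend appears unaffected")
except RuntimeError as err:
    print(f"eigh failed as described in this issue: {err}")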
Proposed fix:
My proposed fix for EkfacInfluence is to check the sG and sA variables and make sure they are small enough not to run into the 32-bit integer problem. This also affected the influence_sentiment_analysis notebook, but I already took the large linear layers out of Ekfac in another PR #550.
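A minimal sketch of such a guard could look like this (the wrapper name, the exact threshold constant, and the message wording are assumptions for illustration, not pyDVL code):

import torch

# Hypothetical guard; 2895 is the empirically observed limit from this issue,
# the exact threshold is an assumption.
MAX_EIGH_SIZE = 2895

def checked_eigh(matrix: torch.Tensor):
    # Fail early with a readable message instead of the backend assert.
    if matrix.shape[-1] > MAX_EIGH_SIZE:
        raise ValueError(
            f"Layer dimension {matrix.shape[-1]} exceeds {MAX_EIGH_SIZE}; "
            "torch.linalg.eigh may hit a 32-bit integer overflow on this "
            "backend. Consider excluding this layer from EK-FAC."
        )
    return torch.linalg.eigh(matrix)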
@schroedk, comments? Update: this seems to be MPS-specific, or at least not a problem with (my version of) CUDA.
@jakobkruse1 I would not suggest that we modify the code for every architecture. I would rather try/except on the RuntimeError and produce a meaningful error message, maybe even referencing the torch issue. What do you think?
I agree @schroedk. It seems like this is an MKL-specific error; it was also working on my CUDA machine. A try/except seems like the easiest way to solve this quickly while we wait for torch to fix it.
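A minimal sketch of that try/except wrapper (name and message are illustrative only, not the final implementation):

import torch

# Sketch of the try/except approach discussed above; re-raises the cryptic
# backend failure with a message pointing at the upstream torch issue.
def eigh_with_hint(matrix: torch.Tensor):
    try:
        return torch.linalg.eigh(matrix)
    except RuntimeError as err:
        raise RuntimeError(
            "torch.linalg.eigh failed, most likely due to a 32-bit integer "
            "overflow in the backend for large matrices "
            f"(size {matrix.shape[-1]}). See the corresponding PyTorch issue; "
            "consider excluding very large linear layers from EK-FAC."
        ) from err

Both eigh calls in EkfacInfluence could then be routed through such a wrapper.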