
EkfacInfluence breaks for large LinearLayer

Open jakobkruse1 opened this issue 1 year ago • 3 comments

Due to an issue in PyTorch, EkfacInfluence does not work for linear layers with an input or output size larger than 2895. The root cause is a 32-bit integer overflow inside torch.linalg.eigh. The only real fix is to use 64-bit integers in the backend; PyTorch is working on that AFAIK, see the issue here.

In EkfacInfluence, we are using

evals_a, evecs_a = torch.linalg.eigh(forward_x[key])
evals_g, evecs_g = torch.linalg.eigh(grad_y[key])

with forward_x and grad_y coming from _init_layer_kfac_blocks:

sG = module.out_features
sA = module.in_features + int(with_bias)
forward_x_layer = torch.zeros((sA, sA), device=module.weight.device)
grad_y_layer = torch.zeros((sG, sG), device=module.weight.device)

Basically, the matrix torch.zeros((sA, sA), ...) becomes too large for torch.linalg.eigh to handle, which then fails with a rather opaque error message:

RuntimeError: false INTERNAL ASSERT FAILED at ".../aten/src/ATen/native/BatchLinearAlgebra.cpp":1462, 
please report a bug to PyTorch. linalg.eigh: Argument 8 has illegal value. 
Most certainly there is a bug in the implementation calling the backend library.

Proposed fix: check the sG and sA variables in EkfacInfluence and make sure they are small enough not to run into the 32-bit integer overflow. This also affected the influence_sentiment_analysis notebook, but I already removed the large linear layers from Ekfac in PR #550.
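A minimal sketch of what such a size check could look like, using the sA/sG computation quoted above. The threshold constant and function names are illustrative only (they are not pyDVL's actual API); 2895 is the empirical limit from this report:

```python
# Hypothetical sketch of the proposed guard, not pyDVL's actual code.
# EIGH_SAFE_LIMIT is the empirical threshold from this issue report.
EIGH_SAFE_LIMIT = 2895


def kfac_block_sizes(in_features, out_features, with_bias=True):
    """Mirror the sA/sG computation from _init_layer_kfac_blocks."""
    sA = in_features + int(with_bias)  # size of the forward_x block
    sG = out_features                  # size of the grad_y block
    return sA, sG


def layer_is_eigh_safe(in_features, out_features, with_bias=True):
    """True if both KFAC blocks stay below the problematic eigh size."""
    sA, sG = kfac_block_sizes(in_features, out_features, with_bias)
    return sA <= EIGH_SAFE_LIMIT and sG <= EIGH_SAFE_LIMIT
```

Layers failing this check would then be skipped (as was done manually for the notebook in #550) instead of crashing inside torch.linalg.eigh.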

jakobkruse1 avatar Apr 02 '24 13:04 jakobkruse1

@schroedk, comments? Update: this seems to be MPS-specific, or at least not a problem with (my version of) CUDA.

mdbenito avatar Apr 03 '24 09:04 mdbenito

@jakobkruse1 I would suggest we not modify the code for every architecture. I would rather try/except the RuntimeError and raise a meaningful error message, maybe even referencing the torch issue. What do you think?
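Something along these lines, as a rough sketch (safe_eigh is a hypothetical name; eigh_fn stands in for torch.linalg.eigh so the idea is shown without depending on a particular torch build):

```python
# Hypothetical sketch of the try/except approach: catch the opaque
# INTERNAL ASSERT failure and re-raise with an actionable message.
def safe_eigh(matrix, eigh_fn):
    """Call eigh_fn(matrix), translating its RuntimeError into a
    message that points at the known 32-bit overflow in the backend."""
    try:
        return eigh_fn(matrix)
    except RuntimeError as err:
        raise RuntimeError(
            "torch.linalg.eigh failed, likely due to a known 32-bit integer "
            "overflow in the LAPACK backend for large matrices. Consider "
            "excluding large linear layers from EkfacInfluence; see the "
            "upstream PyTorch issue for details."
        ) from err
```

Raising `from err` keeps the original traceback attached, so the underlying backend error is still visible for bug reports.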

schroedk avatar Apr 08 '24 09:04 schroedk

I agree @schroedk. It seems this is an MKL-specific error; it was also working on my CUDA machine. try/except seems like the easiest way to solve this quickly while we wait for torch to fix it.

jakobkruse1 avatar Apr 10 '24 08:04 jakobkruse1