Rework orthogonalization in Lanczos without direct Torch access
The dndarray __getitem__ function appeared to be rather slow in the Gram-Schmidt orthogonalization of the Lanczos algorithm (linalg.py):
    for i in Lanczos_Iterations:
        ...
        for j in range(i):
            vr = vr - projection(vr, V[:, j])

--> ca. 60 s for 300 iterations for __getitem__
Performing only the item access in each of the 300*(300+1)/2 inner iterations, without any arithmetic, still yielded a computation time of roughly 40 s:
    for i in Lanczos_Iterations:
        ...
        for j in range(i):
            temp = V[:, j]
The projection(a, b) function also used item access and was rather slow:
projection(a,b) = (dot(a, b) / dot(b, b)) * b
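For reference, the formula and the inner loop above can be sketched in plain Python. This is only an illustration on Python lists, not the heat implementation; `dot`, `projection` and `orthogonalize` here are stand-in helpers written for this example.

```python
def dot(a, b):
    """Inner product of two equally long sequences of numbers."""
    return sum(x * y for x, y in zip(a, b))


def projection(a, b):
    """Project a onto b: (dot(a, b) / dot(b, b)) * b."""
    scale = dot(a, b) / dot(b, b)
    return [scale * y for y in b]


def orthogonalize(vr, basis):
    """Subtract from vr its projection onto each basis vector in turn.

    vr is updated before the next projection is taken, which mirrors
    the loop `vr = vr - projection(vr, V[:, j])`.
    """
    for b in basis:
        p = projection(vr, b)
        vr = [x - y for x, y in zip(vr, p)]
    return vr


if __name__ == "__main__":
    basis = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    print(orthogonalize([1.0, 1.0, 1.0], basis))
```

Each call to `projection` needs two dot products and one column of the basis, so the cost of fetching `V[:, j]` is paid in every inner iteration.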
To speed this up, an arithmetic solution was implemented that calculates parts of the dot products on the local torch level and uses allreduce to combine them. However, this is not nice, as it exposes low-level communication to the user. We should think of a way to abstract this. @Markus-Goetz @coquelin77
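To make the idea concrete, here is a rough sketch of the "local dot product plus allreduce" pattern. It does not use heat, torch or MPI: the ranks are simulated with list chunks, and `split`, `local_dot`, `simulated_allreduce` and `distributed_projection` are hypothetical names made up for this example. In a real setup the reduction would be an MPI sum allreduce across processes.

```python
def split(vec, n_ranks):
    """Split vec into n_ranks contiguous chunks (stand-in for a distributed array)."""
    size, rest = divmod(len(vec), n_ranks)
    chunks, start = [], 0
    for r in range(n_ranks):
        stop = start + size + (1 if r < rest else 0)
        chunks.append(vec[start:stop])
        start = stop
    return chunks


def local_dot(a_chunk, b_chunk):
    """Partial dot product computed on one rank's local data only."""
    return sum(x * y for x, y in zip(a_chunk, b_chunk))


def simulated_allreduce(partials):
    """Stand-in for a sum allreduce: every rank receives the global sum."""
    total = sum(partials)
    return [total] * len(partials)


def distributed_projection(a_chunks, b_chunks):
    """Compute (dot(a, b) / dot(b, b)) * b chunk by chunk.

    Only two scalars per rank are communicated; the vectors stay local.
    """
    ab = simulated_allreduce([local_dot(a, b) for a, b in zip(a_chunks, b_chunks)])
    bb = simulated_allreduce([local_dot(b, b) for b in b_chunks])
    return [[(ab[r] / bb[r]) * y for y in b] for r, b in enumerate(b_chunks)]


if __name__ == "__main__":
    a = split([1.0, 2.0, 3.0, 4.0, 5.0], 2)
    b = split([1.0, 1.0, 1.0, 1.0, 1.0], 2)
    print(distributed_projection(a, b))
```

One possible abstraction would be to hide `local_dot` and the reduction behind a single distributed `dot`, so that `projection` can be written without any visible communication.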