NNlib.jl

More batched functions such as batched_svd, batched_diagm

Open sychen52 opened this issue 3 years ago • 7 comments

To write a multiple-view geometry package that is easy to use in deep learning, I think the input and output tensors should have a batch dimension. I also need batched versions of a few linear algebra functions, such as torch.bmm, torch.svd, and torch.diag_embed. I traced into the NNlib module and noticed it has batched_mul and batched_transpose/batched_adjoint, but no batched svd or diagm.

Is this the correct place to add these batched version functions?
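
As an illustration of what is being asked for, here is a minimal sketch of a batched diagm in plain Julia. The name `batched_diagm` and the batch-last layout (n × B in, n × n × B out) are assumptions made for this example; no such function is confirmed to exist in NNlib.

```julia
using LinearAlgebra

# Hypothetical helper, not part of NNlib: each column of `x` (size n × B)
# becomes the diagonal of one n × n slice of the output (batch dimension last).
function batched_diagm(x::AbstractMatrix{T}) where {T}
    n, B = size(x)
    out = zeros(T, n, n, B)
    for b in 1:B, i in 1:n
        out[i, i, b] = x[i, b]
    end
    return out
end

x = [1.0 3.0; 2.0 4.0]   # two batch elements: [1, 2] and [3, 4]
D = batched_diagm(x)     # 2 × 2 × 2 array of diagonal matrices
```

This roughly mirrors what torch.diag_embed does for a 2-D input, except that the batch dimension comes last rather than first, matching the convention batched_mul uses.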

sychen52, Mar 07 '22 15:03

Or is there a more elegant solution than writing all these batched functions and adding them to NNlib?

sychen52, Mar 07 '22 15:03

@Roger-luo was going to collect many of these in https://github.com/Roger-luo/BatchedRoutines.jl but that was a while ago. Otherwise here sounds OK to me.

batched_mul got pretty complicated because it tries to accept every PermutedDimsArray layout that BLAS can handle directly, which was possibly a mistake. If batched_svd is less ambitious, it could be much simpler. But what would it return?

Are there batched CUDA versions of these functions?
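
One possible answer to the return-type question, as a sketch only: loop over slices on the CPU and return three plain arrays with the batch dimension last. The name `batched_svd`, the restriction to dense `Float32`/`Float64` input, and the `(U, S, V)` tuple are all assumptions for this example, not an agreed NNlib design.

```julia
using LinearAlgebra

# Hypothetical sketch, not an NNlib function: thin SVD of every slice A[:, :, b].
# Returns plain arrays U (m × k × B), S (k × B), V (n × k × B), with k = min(m, n).
function batched_svd(A::AbstractArray{T,3}) where {T<:Union{Float32,Float64}}
    m, n, B = size(A)
    k = min(m, n)
    U = similar(A, m, k, B)
    S = similar(A, k, B)
    V = similar(A, n, k, B)
    for b in 1:B
        F = svd(A[:, :, b])   # LinearAlgebra.svd returns the thin factorization by default
        U[:, :, b] .= F.U
        S[:, b] .= F.S
        V[:, :, b] .= F.V
    end
    return U, S, V
end

A = rand(4, 3, 5)
U, S, V = batched_svd(A)
# Each slice satisfies A[:, :, b] ≈ U[:, :, b] * Diagonal(S[:, b]) * V[:, :, b]'
```

Returning plain arrays rather than a vector of `SVD` objects keeps the result usable with batched_mul, at the cost of one copy per slice.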

mcabbott, Mar 07 '22 16:03

There are a few batched linear algebra functions in MAGMA, which we tried to wrap in Julia, but we had issues with BB (BinaryBuilder). See Magma.jl.

Roger-luo, Mar 07 '22 21:03

Am I understanding this correctly? To have batched_svd in NNlib, we need 1) a batched CPU version using LAPACK (maybe in BatchedRoutines.jl), 2) a batched CUDA version using MAGMA (in Magma.jl), and 3) a unified function API in NNlib that calls the CPU or GPU version underneath.
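
The three-layer layout described above could be sketched with ordinary multiple dispatch. Every name here (`batched_svdvals`, `_batched_svdvals`) is hypothetical, and only the CPU backend is implemented; a GPU package would add its own method for its array type.

```julia
using LinearAlgebra

# Hypothetical names throughout; none of this is existing NNlib API.
# Public entry point: forwards to a backend method selected by the array type.
batched_svdvals(A::AbstractArray{<:Real,3}) = _batched_svdvals(A)

# CPU backend: dense Arrays go through LAPACK via LinearAlgebra.svdvals.
function _batched_svdvals(A::Array{T,3}) where {T<:Union{Float32,Float64}}
    k = min(size(A, 1), size(A, 2))
    S = Matrix{T}(undef, k, size(A, 3))
    for b in axes(A, 3)
        S[:, b] = svdvals(A[:, :, b])
    end
    return S
end

# Fallback: no backend registered for this array type yet.
# A GPU package would add a method such as `_batched_svdvals(A::CuArray{T,3})`.
_batched_svdvals(A::AbstractArray) =
    throw(ArgumentError("no batched SVD backend for $(typeof(A))"))

A = rand(3, 2, 4)
S = batched_svdvals(A)   # 2 × 4 matrix of singular values, one column per slice
```

Callers only ever see `batched_svdvals`; which library does the work is decided by the argument's type.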

sychen52, Mar 08 '22 06:03

Judging by https://github.com/pytorch/pytorch/blob/5dbec7c07c5eedd748fd56359c2d1b980dfa1037/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp, the MAGMA dependency may not be required for GPU support.

ToucheSir, Mar 08 '22 15:03

@ToucheSir is there any example code you could point to for getting started on writing batched_svd? How would you recommend approaching it?

nikopj, Feb 08 '23 04:02

I'm almost certainly the one with the least linear algebra knowledge on this thread and thus not the one you want to ask such questions :)

ToucheSir, Feb 08 '23 05:02