gmm-torch
Optimized memory usage and speed for covar type "full"
Improved speed and memory usage with the following optimizations:

- the `(N, K, 1, D) * (1, K, D, D)` matmul at line 275 is replaced with an equivalent matmul `(K, N, D) * (K, D, D)`. `(N, K, 1, D) * (1, K, D, D)` is interpreted by cuBLAS as a batched matrix-vector product, while `(K, N, D) * (K, D, D)` is a batched matrix-matrix product, which is more efficient on GPUs (first sketch below).
- in two consecutive iterations of `fit`, `_estimate_log_prob` was being called twice with the same input, in `_e_step` and `__score`. Now `weighted_log_probs` is computed only once, in `__score` of the previous iteration, then cached and reused in `_e_step` of the next iteration (second sketch below).
- at line 342, `mu` was originally obtained by element-wise multiplication and summation, which is now simplified to a matmul (third sketch below).
- at line 346, the batched vector outer product followed by summation is rewritten as a single batched matmul, which is more efficient on GPUs (also covered by the third sketch below).
- computations in `_m_step` and `_estimate_log_prob` are split into smaller "chunks" to prevent OOM as much as possible (fourth sketch below).
- added an option to choose the dtype of the covariance matrix. `torch.linalg.eigvals` is used to compute `log_det` if `covariance_data_type = torch.float`; otherwise a Cholesky decomposition is used (fifth sketch below).
- replaced some of the tensor-scalar and tensor-tensor additions/multiplications with their in-place counterparts to reduce unnecessary memory allocation (last sketch below).
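A minimal sketch of the reshape from the first item, using stand-in tensors (`x` for the broadcasted operand, `mats` for the per-component matrices; the actual variable names in gmm-torch differ):

```python
import torch

N, K, D = 512, 16, 8           # samples, components, features
x = torch.randn(N, K, D)       # stand-in for the (N, K, D)-shaped operand
mats = torch.randn(K, D, D)    # one (D, D) matrix per component

# original form: broadcast to (N, K, 1, D) @ (1, K, D, D), which cuBLAS
# executes as N*K separate matrix-vector products
out_mv = x.unsqueeze(2) @ mats.unsqueeze(0)    # (N, K, 1, D)

# rewritten form: one batched matrix-matrix product per component,
# (K, N, D) @ (K, D, D), then reshaped back
out_mm = (x.transpose(0, 1) @ mats).transpose(0, 1).unsqueeze(2)

assert torch.allclose(out_mv, out_mm, atol=1e-5)
```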
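A hypothetical sketch of the caching pattern from the second item. The class, the isotropic stand-in density, and the attribute names are illustrative only, not the actual gmm-torch implementation:

```python
import torch

class CachedScoreSketch:
    """Illustrative only: cache weighted log probs between score and e-step."""

    def __init__(self, mu, log_pi):
        self.mu, self.log_pi = mu, log_pi   # (K, D) means, (K,) log weights
        self._cached = None                 # weighted log prob from the last score

    def _estimate_log_prob(self, x):
        # stand-in density (isotropic Gaussians), returns (N, K)
        return -0.5 * ((x[:, None, :] - self.mu[None]) ** 2).sum(-1)

    def _e_step(self, x):
        # reuse the tensor computed by score() in the previous iteration
        wlp = (self._cached if self._cached is not None
               else self._estimate_log_prob(x) + self.log_pi)
        self._cached = None
        log_norm = torch.logsumexp(wlp, dim=1, keepdim=True)
        return log_norm.mean(), wlp - log_norm  # mean log-likelihood, log resp.

    def score(self, x):
        wlp = self._estimate_log_prob(x) + self.log_pi
        self._cached = wlp                  # cache for the next _e_step call
        return torch.logsumexp(wlp, dim=1).mean()
```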
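The two M-step rewrites (lines 342 and 346) can be checked against the element-wise versions on random data; `resp` here stands in for the (N, K) responsibilities:

```python
import torch

N, K, D = 1000, 5, 3
x = torch.randn(N, D)                            # data
resp = torch.softmax(torch.randn(N, K), dim=1)   # responsibilities (N, K)
nk = resp.sum(0)                                 # per-component weights (K,)

# mu, element-wise multiply + sum over N ...
mu_old = (resp[:, :, None] * x[:, None, :]).sum(0) / nk[:, None]
# ... simplified to a single matmul: (K, N) @ (N, D) -> (K, D)
mu = resp.T @ x / nk[:, None]
assert torch.allclose(mu_old, mu, atol=1e-5)

# covariances, batched vector outer products summed over N ...
diff = x[:, None, :] - mu[None, :, :]            # (N, K, D)
covar_old = (resp[:, :, None, None]
             * diff[:, :, :, None] * diff[:, :, None, :]).sum(0) / nk[:, None, None]
# ... rewritten as one batched matmul: (K, D, N) @ (K, N, D) -> (K, D, D)
wdiff = diff.permute(1, 0, 2)                    # (K, N, D)
covar = (wdiff * resp.T[:, :, None]).transpose(1, 2) @ wdiff / nk[:, None, None]
assert torch.allclose(covar_old, covar, atol=1e-4)
```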
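The chunking in `_m_step` and `_estimate_log_prob` amounts to the following generic pattern; the helper name and chunk size here are made up for illustration:

```python
import torch

def apply_chunked(fn, x, chunk_size=2048):
    """Evaluate fn over row-chunks of x so that only one chunk's
    intermediate tensors are alive at a time, capping peak memory."""
    return torch.cat([fn(chunk) for chunk in x.split(chunk_size, dim=0)], dim=0)

# e.g. log_prob = apply_chunked(model._estimate_log_prob, x)
```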
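The dtype-dependent `log_det` could look roughly like this (a sketch, not the PR's exact code; the eigenvalue route for float32 is presumably chosen because single-precision covariances can be SPD only up to rounding, where Cholesky would fail):

```python
import torch

def log_det(covar, covariance_data_type=torch.double):
    """Sketch: log|covar| for a batch of (..., D, D) covariance matrices."""
    if covariance_data_type == torch.float:
        # eigenvalue route used for float32
        eigvals = torch.linalg.eigvals(covar)   # complex-valued (..., D)
        return torch.log(eigvals.real).sum(-1)
    # Cholesky route: covar = L L^T, so log|covar| = 2 * sum(log diag(L))
    L = torch.linalg.cholesky(covar)
    return 2.0 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(-1)
```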
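And the in-place rewrites follow the standard PyTorch pattern of trading fresh allocations for `_`-suffixed ops:

```python
import torch

x = torch.randn(10000, 64)
y = (x * 2.0) + 1.0      # out-of-place: allocates two temporary tensors

x.mul_(2.0).add_(1.0)    # in-place: reuses x's storage, no new allocations
assert torch.allclose(x, y)
```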
Remaining issues:

- when `covariance_data_type = "float"` and both `n_components` and `n_features` are large, `covar` contains NaN.