
Optimized memory usage and speed for covar type "full"

Open · DeMoriarty opened this issue on Mar 13, 2022 · 0 comments

Improved speed and memory usage with the following optimizations (illustrative sketches follow the list):

  1. The (N, K, 1, D) * (1, K, D, D) matmul at line 275 is replaced with the equivalent matmul (K, N, D) * (K, D, D). cuBLAS interprets (N, K, 1, D) * (1, K, D, D) as a batched matrix-vector product, whereas (K, N, D) * (K, D, D) is a batched matrix-matrix product, which is more efficient on GPUs.

  2. Across two consecutive iterations of fit, _estimate_log_prob was called twice with the same input: once in _e_step and once in __score. Now weighted_log_probs is computed only once, in __score of the previous iteration, and cached for reuse in _e_step of the next iteration.

  3. At line 342, mu was originally obtained by element-wise multiplication and summation; this is now simplified to a matmul.

  4. At line 346, the batched vector outer product followed by a summation is rewritten as a single batched matmul, which is more efficient on GPUs.

  5. The computations in _m_step and _estimate_log_prob are split into smaller chunks to prevent OOM as much as possible.

  6. Added an option to choose the dtype of the covariance matrix. torch.linalg.eigvals is used to compute log_det when covariance_data_type = torch.float; otherwise the Cholesky decomposition is used.

  7. Replaced some of the tensor-scalar and tensor-tensor additions/multiplications with their in-place counterparts to reduce unnecessary memory allocation.
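
A minimal sketch of item 1, with made-up sizes and a random stand-in for the per-component (D, D) matrices (the names N, K, D, x, prec are illustrative, not the library's internals):

```python
import torch

N, K, D = 4096, 16, 64       # hypothetical sizes: samples, components, features
x = torch.randn(N, D)
prec = torch.randn(K, D, D)  # stand-in for the per-component (D, D) matrices

# Original form: broadcast to (N, K, 1, D) @ (1, K, D, D) -> (N, K, 1, D).
# cuBLAS sees N*K independent (1, D) @ (D, D) matrix-vector products.
y_slow = x[:, None, None, :] @ prec[None]

# Optimized form: (K, N, D) @ (K, D, D) -> (K, N, D), a batched
# matrix-matrix product that GPUs execute much more efficiently.
y_fast = x.expand(K, N, D) @ prec

# The two layouts hold the same values.
assert torch.allclose(y_slow.squeeze(2).transpose(0, 1), y_fast)
```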
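
Item 2 amounts to a one-slot cache. A simplified sketch, assuming a toy log-density; GMMCacheSketch and its method bodies are hypothetical stand-ins, not the actual gmm-torch class:

```python
import torch

class GMMCacheSketch:
    """Hypothetical stand-in for the gmm-torch class, showing only the cache."""

    def __init__(self, n_components, n_features):
        self.mu = torch.zeros(n_components, n_features)
        self.log_pi = torch.log(torch.ones(1, n_components) / n_components)
        self._cached_wlp = None  # weighted_log_probs from the previous __score

    def _estimate_log_prob(self, x):
        # Toy spherical log-density; the real method is the expensive one.
        return -0.5 * ((x[:, None, :] - self.mu) ** 2).sum(-1)

    def _e_step(self, x):
        # Reuse the value cached by the previous iteration's __score instead
        # of calling _estimate_log_prob again with the same input.
        if self._cached_wlp is not None:
            wlp, self._cached_wlp = self._cached_wlp, None
        else:
            wlp = self._estimate_log_prob(x) + self.log_pi
        log_norm = torch.logsumexp(wlp, dim=1, keepdim=True)
        return log_norm.mean(), wlp - log_norm  # mean log-likelihood, log-resp

    def _score(self, x):
        wlp = self._estimate_log_prob(x) + self.log_pi
        self._cached_wlp = wlp  # cache for the next _e_step
        return torch.logsumexp(wlp, dim=1).mean()
```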
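
Items 3 and 4 together turn the M-step reductions into matmuls. A sketch with assumed names (x, resp, Nk are illustrative, not the library's internals):

```python
import torch

N, K, D = 4096, 16, 64
x = torch.randn(N, D)
resp = torch.softmax(torch.randn(N, K), dim=1)  # responsibilities, (N, K)
Nk = resp.sum(dim=0)                            # per-component weight, (K,)

# Item 3: mu[k] = sum_n resp[n, k] * x[n] / Nk[k]. The element-wise
# multiply-and-sum collapses into one matmul:
mu = (resp.t() @ x) / Nk[:, None]               # (K, D)

# Item 4: covar[k] = sum_n resp[n, k] * outer(x[n] - mu[k], x[n] - mu[k]).
# Fold the weighted sum over n into a single batched matmul:
diff = x[None, :, :] - mu[:, None, :]           # (K, N, D)
covar = (diff * resp.t()[:, :, None]).transpose(1, 2) @ diff / Nk[:, None, None]
```

Folding the sum into the matmul also means the (N, K, D, D) stack of outer products is never materialized; the reduction happens inside one cuBLAS kernel.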
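
Item 5 in its simplest form: a hypothetical helper (chunked_apply and chunk_size are made-up names) that bounds peak memory by processing samples in slices:

```python
import torch

def chunked_apply(fn, x, chunk_size=2048):
    # Evaluate fn over slices of x so that intermediates of shape
    # (chunk_size, K, D) or (chunk_size, K, D, D) replace ones of shape
    # (N, ...), trading a Python loop for bounded peak memory.
    return torch.cat(
        [fn(x[i:i + chunk_size]) for i in range(0, x.shape[0], chunk_size)],
        dim=0,
    )

# Example: a per-sample computation applied 2048 rows at a time.
out = chunked_apply(lambda c: (c ** 2).sum(-1, keepdim=True), torch.randn(10_000, 64))
```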
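
For item 6, the two log-determinant paths might look like this; the function name and signature are illustrative, and only the torch.linalg.eigvals-vs-Cholesky split comes from the issue:

```python
import torch

def log_det(covar, covariance_data_type=torch.float):
    """Illustrative sketch; the real gmm-torch signature may differ."""
    if covariance_data_type == torch.float:
        # float32 round-off can make covar slightly non-positive-definite,
        # which breaks Cholesky; eigenvalues still give a usable log-det.
        evals = torch.linalg.eigvals(covar)       # complex, shape (K, D)
        return torch.log(evals).sum(dim=-1).real  # shape (K,)
    # In double precision the cheaper Cholesky route is numerically safe:
    # log det(C) = 2 * sum(log(diag(L))), where C = L L^T.
    L = torch.linalg.cholesky(covar)
    return 2.0 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(dim=-1)
```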
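
Item 7 is the standard in-place trick; a generic example rather than a specific line from the library:

```python
import torch

x, mu = torch.randn(1024, 64), torch.randn(64)

# Out-of-place: each op allocates a fresh (1024, 64) tensor.
y = ((x - mu) * 0.5) + 1.0

# In-place: the buffer created by (x - mu) is reused for the follow-up ops.
y = (x - mu).mul_(0.5).add_(1.0)
```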

[benchmark results]

Remaining issues:

  1. When covariance_data_type = "float" and both n_components and n_features are large, covar contains NaNs.

DeMoriarty · Mar 13 '22 19:03