
FastLogSemiring

Open w-cheng opened this issue 4 years ago • 3 comments

Hi,

Thanks for making this library; it's great to have these different CRFs wrapped up in a common, easy-to-use framework.

I've been playing with the LinearChainCRF, and one thing I noticed is that memory usage can be very high during the loss backward pass on both CPU and GPU. I found that FastLogSemiring in fast_semirings.py uses genbmm.logbmm() and significantly reduces memory usage on GPU when I change the default LogSemiring used in the StructDistribution class to FastLogSemiring. However, I haven't seen this documented anywhere, so my questions are:

  1. Is FastLogSemiring ready to be used? It's not included in test_semirings.py.
  2. If so, what would be the best way to switch between LogSemiring and FastLogSemiring? Is there a plan to introduce a parameter for choosing the semiring in the StructDistribution class?

w-cheng avatar Sep 30 '21 12:09 w-cheng

Yes! It works and is heavily tested. We should make it the default. It just requires the GPU kernels in genbmm to be installed.

srush avatar Sep 30 '21 17:09 srush

What do you think of checking for the genbmm library at import time, like:

has_genbmm = False
try:
    import genbmm

    has_genbmm = True
    from .semirings import FastLogSemiring
except ImportError:
    pass

and then a method on the StructDistribution class:

    def default_log_semiring(self):
        return FastLogSemiring if has_genbmm and self.log_potentials.is_cuda else LogSemiring

Then, instead of returning LogSemiring by default in the marginals and partition properties, we call this default_log_semiring().
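As a self-contained illustration of the guarded-import fallback being proposed, here is a minimal sketch. The LogSemiring and FastLogSemiring classes below are placeholders standing in for the real torch-struct classes, so the dispatch logic runs even without torch-struct or genbmm installed; `default_log_semiring` takes a plain `is_cuda` flag instead of reading `self.log_potentials.is_cuda`.

```python
class LogSemiring:
    """Placeholder for torch_struct's LogSemiring (pure PyTorch)."""


class FastLogSemiring:
    """Placeholder for the genbmm-backed semiring; the real one needs CUDA kernels."""


# Detect the optional dependency once, at import time.
has_genbmm = False
try:
    import genbmm  # optional package providing fused logbmm GPU kernels

    has_genbmm = True
except ImportError:
    pass


def default_log_semiring(is_cuda):
    """Use the fast kernel only when genbmm is installed and tensors are on GPU."""
    return FastLogSemiring if (has_genbmm and is_cuda) else LogSemiring
```

On a machine without genbmm (or with CPU tensors), this silently falls back to the plain LogSemiring, so existing behavior is unchanged.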

w-cheng avatar Oct 05 '21 16:10 w-cheng

Yes, that would be great. You can do it for max too.

srush avatar Oct 06 '21 00:10 srush