pytorch-struct
FastLogSemiring
Hi,
Thanks for making this library; it's amazing to have these different CRFs wrapped up in a common, easy-to-use framework.
I've been playing with the LinearChainCRF, and one thing I noticed is that memory usage can be very high during the loss backward pass on both CPU and GPU. I found that the FastLogSemiring in fast_semirings.py uses genbmm.logbmm() and significantly reduces memory usage on GPU if I change the default LogSemiring used in the StructDistribution class to FastLogSemiring. However, I haven't seen this documented anywhere, so my questions are:
- Is `FastLogSemiring` ready to be used? It's not included in `test_semirings.py`.
- If so, what would be the best way to switch between `LogSemiring` and `FastLogSemiring`? Is there a plan to introduce a parameter to choose between the semirings in the `StructDistribution` class?
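For context, as I understand it genbmm.logbmm computes a log-space batched matrix multiply, i.e. a numerically stable log(exp(A) @ exp(B)), in a fused kernel, so the large broadcasted intermediate tensor that a plain-PyTorch log semiring would materialize is never allocated. A minimal pure-Python sketch of the per-matrix semantics (not the fused GPU implementation):

```python
import math

def logbmm(a, b):
    """Log-space matrix multiply: out[i][j] = logsumexp_k(a[i][k] + b[k][j]).

    Equivalent to log(exp(A) @ exp(B)), computed stably by shifting
    by the max term before exponentiating.
    """
    n, m, p = len(a), len(b), len(b[0])
    out = [[0.0] * p for _ in range(n)]
    for i in range(n):
        for j in range(p):
            terms = [a[i][k] + b[k][j] for k in range(m)]
            mx = max(terms)
            out[i][j] = mx + math.log(sum(math.exp(t - mx) for t in terms))
    return out
```

The fused kernel saves memory precisely because `terms` (the k-dimension of products) is reduced on the fly instead of being stored as a full tensor for autograd.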
Yes! It works and is heavily tested. We should make it the default. It just requires that the genbmm GPU kernels be installed.
What do you think of performing a check for the genbmm library in the imports, like:
```python
has_genbmm = False
try:
    import genbmm

    has_genbmm = True
    from .semirings import FastLogSemiring
except ImportError:
    pass
```
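The guard pattern above can be exercised standalone; here a deliberately nonexistent module name stands in for genbmm so that the ImportError fallback path runs:

```python
has_genbmm = False
try:
    import genbmm_not_installed_placeholder  # hypothetical stand-in for genbmm
    has_genbmm = True
except ImportError:
    pass  # genbmm unavailable: keep using the pure-PyTorch LogSemiring

print(has_genbmm)  # prints: False
```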
then a function in the `StructDistribution` class:
```python
def default_log_semiring(self):
    return FastLogSemiring if has_genbmm and self.log_potentials.is_cuda else LogSemiring
```
So instead of returning `LogSemiring` by default in the `marginals` and `partition` properties, we would call this `default_log_semiring()`.
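As a self-contained sketch of that dispatch (the class and semiring names here are stand-in placeholders mirroring the snippet above, not the real pytorch-struct classes):

```python
class LogSemiring:
    """Stand-in for the pure-PyTorch log semiring."""

class FastLogSemiring:
    """Stand-in for the genbmm-backed log semiring."""

has_genbmm = True  # would be set by the guarded import

class StructDistribution:
    """Minimal stand-in showing only the semiring dispatch."""

    def __init__(self, log_potentials):
        self.log_potentials = log_potentials

    def default_log_semiring(self):
        # Use the fused GPU kernel only when genbmm is installed
        # and the potentials actually live on a CUDA device.
        is_cuda = getattr(self.log_potentials, "is_cuda", False)
        return FastLogSemiring if has_genbmm and is_cuda else LogSemiring

class FakeCudaTensor:
    is_cuda = True

class FakeCpuTensor:
    is_cuda = False

print(StructDistribution(FakeCudaTensor()).default_log_semiring().__name__)  # prints: FastLogSemiring
print(StructDistribution(FakeCpuTensor()).default_log_semiring().__name__)   # prints: LogSemiring
```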
Yes, that would be great. You can do it for max too.