[RFC] Automatic tuning of chunk size in VMC driver

Open wdphy16 opened this issue 4 years ago • 1 comment

This trick has already helped me a lot when training large models. We start by setting the chunk size to a large number, and whenever it causes an out-of-memory (OOM) error, we halve it. A good initial value is n_samples_per_rank * hilbert.size * a multiplier, because for each sample the number of connected configurations in the local energy is of order hilbert.size. Following Clemens' advice, we keep the chunk size a power of 2 for performance reasons.
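The sizing rule above is plain arithmetic, so it can be checked in isolation. Below is a minimal sketch of the starting value and the sequence of chunk sizes the halving strategy would try. The function names `initial_chunk_size` and `halving_schedule` are hypothetical and not part of NetKet, and the default multiplier of 16 is taken from the implementation proposed in this issue.

```python
def initial_chunk_size(n_samples_per_rank: int, hilbert_size: int, multiplier: int = 16) -> int:
    """Heuristic starting chunk size, rounded up to the next power of 2."""
    estimate = n_samples_per_rank * hilbert_size * multiplier
    # Smallest power of 2 >= estimate, using exact integer arithmetic (estimate >= 1)
    return 1 << (estimate - 1).bit_length()


def halving_schedule(start: int, min_chunk_size: int) -> list:
    """Chunk sizes tried in order: halve after every OOM until below the minimum."""
    sizes = []
    size = start
    while size >= min_chunk_size:
        sizes.append(size)
        size //= 2
    return sizes


start = initial_chunk_size(n_samples_per_rank=1008, hilbert_size=20)
print(start)  # 524288
print(halving_schedule(start, min_chunk_size=1008))
```

Here 1008 * 20 * 16 = 322560 is rounded up to 2**19 = 524288. The bit-length trick gives the same result as `2 ** ceil(log2(x))` for realistic sizes, but it avoids floating-point rounding for very large integers.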

A possible implementation:

```python
import warnings
from math import ceil, log2

from netket.driver import VMC


class VMCAutoChunk(VMC):
    def __init__(self, *args, **kwargs):
        init_chunk_size_multiplier = kwargs.pop("init_chunk_size_multiplier", 16)
        min_chunk_size = kwargs.pop("min_chunk_size", None)

        super().__init__(*args, **kwargs)

        chunk_size = self.state.chunk_size
        # If `state.chunk_size` is already set, we use that as the initial value
        if chunk_size is None:
            chunk_size = (
                self.state.n_samples_per_rank
                * self._ham.hilbert.size
                * init_chunk_size_multiplier
            )
        # Round up to a power of 2
        chunk_size = 2 ** int(ceil(log2(chunk_size)))
        self.state.chunk_size = chunk_size

        # Store the lower bound on the instance so `_forward_and_backward` can
        # read it (as a local variable it was not visible there)
        if min_chunk_size is None:
            min_chunk_size = self.state.n_samples_per_rank
        self._min_chunk_size = min_chunk_size

    def _forward_and_backward(self):
        while True:
            try:
                return super()._forward_and_backward()
            except RuntimeError as e:
                # Assumption: JAX/XLA reports OOM as a RuntimeError whose message
                # contains "RESOURCE_EXHAUSTED"; any other error is re-raised
                if "RESOURCE_EXHAUSTED" not in str(e):
                    raise
                chunk_size = self.state.chunk_size // 2
                if chunk_size < self._min_chunk_size:
                    warnings.warn(f"Minimum chunk size {self._min_chunk_size} reached")
                    raise
                warnings.warn(f"Reducing chunk size to {chunk_size}")
                # This driver modifies `state.chunk_size` in place
                self.state.chunk_size = chunk_size
```

If we want to integrate this into NetKet, we can discuss the API in more detail. Similar tricks are already implemented in some high-level ML frameworks, such as toma and PyTorch Lightning.
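The retry-with-halving idea does not depend on NetKet itself. Below is a minimal, framework-agnostic sketch in the spirit of those batch-size retry helpers, though it does not reproduce their actual APIs. It assumes that OOM errors are `RuntimeError`s whose message contains `RESOURCE_EXHAUSTED`, as XLA's typically are. The names `retry_with_halving`, `default_is_oom` and the toy `count_chunks` workload are hypothetical, and the out-of-memory condition is simulated.

```python
import functools


def default_is_oom(exc):
    # Assumption: XLA/JAX out-of-memory errors are RuntimeErrors whose message
    # contains "RESOURCE_EXHAUSTED"; adapt this check to your framework
    return isinstance(exc, RuntimeError) and "RESOURCE_EXHAUSTED" in str(exc)


def retry_with_halving(initial_chunk_size, min_chunk_size=1, is_oom=default_is_oom):
    """Decorator: call fn(..., chunk_size=c), halving c after every OOM.

    The last working chunk size is stored on the wrapper, so later calls start
    from it instead of retrying sizes that already failed.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            while True:
                try:
                    return fn(*args, chunk_size=wrapper.chunk_size, **kwargs)
                except Exception as exc:
                    smaller = wrapper.chunk_size // 2
                    # Re-raise non-OOM errors, or OOM once the minimum is reached
                    if not is_oom(exc) or smaller < min_chunk_size:
                        raise
                    wrapper.chunk_size = smaller

        wrapper.chunk_size = initial_chunk_size
        return wrapper

    return decorator


# Toy workload: pretend the device fits at most 4096 configurations per chunk
MEMORY_LIMIT = 4096


@retry_with_halving(initial_chunk_size=65536, min_chunk_size=256)
def count_chunks(n_configs, chunk_size):
    if chunk_size > MEMORY_LIMIT:
        raise RuntimeError(f"RESOURCE_EXHAUSTED: simulated OOM at chunk_size={chunk_size}")
    # Stand-in for the real work: number of chunks needed to cover n_configs
    return -(-n_configs // chunk_size)


print(count_chunks(100_000))    # 25
print(count_chunks.chunk_size)  # 4096
```

Persisting the chunk size on the wrapper mirrors how the driver above writes the reduced value back into `state.chunk_size`. Only the first call pays for the failed attempts.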

wdphy16, Dec 23 '21 21:12

I'm very much in favour of this and I like it.

However, I'd like this to be something generic that works with an arbitrary MCState rather than only with VMC, so that it can also be used with dynamics.

Maybe a function along these lines:

```python
def estimate_chunk_size(vstate, driver=None):
    if driver is not None:
        try:
            driver.step()
        ....
    else:
        ...
```
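Below is a minimal sketch of how such a generic estimator could behave. It assumes only that the state exposes a writable `chunk_size` attribute, as NetKet's `MCState` does, and that OOM errors are `RuntimeError`s mentioning `RESOURCE_EXHAUSTED`. Instead of the `driver` argument, it takes a hypothetical zero-argument `probe` callable. This could wrap one driver iteration for VMC or dynamics, or an expectation value such as `lambda: vstate.expect(hamiltonian)`. The signature is illustrative, not an existing NetKet API, and the test below uses a fake state.

```python
import warnings


def _is_oom(exc):
    # Assumption: XLA out-of-memory errors are RuntimeErrors mentioning RESOURCE_EXHAUSTED
    return isinstance(exc, RuntimeError) and "RESOURCE_EXHAUSTED" in str(exc)


def estimate_chunk_size(state, probe, initial, min_chunk_size=1, is_oom=_is_oom):
    """Halve `state.chunk_size`, starting from `initial`, until `probe()` runs without OOM.

    `state` is any object with a writable `chunk_size` attribute; `probe` is a
    zero-argument callable that exercises the memory-hungry computation.
    Returns the chunk size that succeeded (also left set on `state`).
    """
    chunk_size = 1 << (initial - 1).bit_length()  # round up to a power of 2
    while True:
        state.chunk_size = chunk_size
        try:
            probe()
            return chunk_size
        except Exception as exc:
            # Re-raise non-OOM errors, or OOM once the minimum is reached
            if not is_oom(exc) or chunk_size // 2 < min_chunk_size:
                raise
            chunk_size //= 2
            warnings.warn(f"Reducing chunk size to {chunk_size}")


# Usage with a stand-in state object and a simulated memory limit
class FakeState:
    chunk_size = None


state = FakeState()


def probe():
    # Simulated device that can only handle chunks of up to 3000 configurations
    if state.chunk_size > 3000:
        raise RuntimeError("RESOURCE_EXHAUSTED: simulated out of memory")


found = estimate_chunk_size(state, probe, initial=20_000, min_chunk_size=256)
print(found)  # 2048
```

Separating "what to run" (`probe`) from "what to tune" (`state.chunk_size`) keeps the helper independent of the driver type, which is the point of making it work for arbitrary MCState users.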

PhilipVinc, Jan 03 '22 17:01