
Great package -- but code quality could be improved


Dear all,

thank you very much for this repo and for implementing this -- I've been enjoying using it. It's great to see well-documented and tested (!) code being written for scientific research.

However, to enable users to adapt this to their own purposes, I had some ideas about how the underlying code could be improved and cleaned up a little. Here are some things I noticed while working with it:

  • At the moment, too many functions are passed parameters that should be fixed class members, such as the metric or the sampler: these do not change between setups and do not need to be passed as args. Such (static) parameters should also not be checked inside functions; e.g. in samplers.py, line 123 (the fisher function), you're raising
    raise ValueError('Unknown metric: {}'.format(metric))
    
    This shouldn't happen here: metric should be a class member that is checked upon instantiation of the class, and never again thereafter.
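    As a rough sketch of what I mean (the class and attribute names here are purely illustrative, not the actual hamiltorch API):

    import torch

    class HMCSampler:
        """ Illustrative sampler skeleton: static configuration is validated exactly once. """

        VALID_METRICS = ('HESSIAN', 'SOFTABS', 'JAC')  # placeholder names

        def __init__(self, log_prob_func, *, metric: str = 'HESSIAN', step_size: float = 0.1, L: int = 10):
            if metric not in self.VALID_METRICS:
                raise ValueError('Unknown metric: {}'.format(metric))
            self.log_prob_func = log_prob_func
            self.metric = metric        # checked here, never again
            self.step_size = step_size
            self.L = L

        def fisher(self, params: torch.Tensor) -> torch.Tensor:
            # No metric check needed any more: self.metric is guaranteed to be valid
            ...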
  • the leapfrog function is too long and should be broken down into smaller sub-functions, which are more legible and easier to understand. For instance, just have separate functions for each step, which are then called in the main routine:
    def _Leapfrog_HMC(self,
                x: torch.Tensor,
                p: torch.Tensor,
                *,
                inv_mass: Optional[Union[torch.Tensor, Sequence]] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
    
       """ Leapfrog integrator. Integrates Hamilton's equations for L steps.
    
       :param x: current sample position
       :param p: current momentum
       :param inv_mass: (optional) inverse mass matrix to be used
       :return: updated position and momentum
       """
    
       x = x.clone(); p = p.clone()
    
       _new_x, _new_p = [], []
    
       # Update the momentum in the direction of the gradient
       p += 0.5 * self.step_size * self.grad_of_sample(x)
    
       # Update the position for L steps, using the mass matrix, if given
       for n in range(self.L):
    
           # If no inverse mass given, update the samples using Hamiltonian dynamics
           if inv_mass is None:
               x += self.step_size * p
           else:
               if isinstance(inv_mass, Sequence):
                   x += self.step_size * _multiply_with_block_matrix(inv_mass, p)
               elif len(inv_mass.shape) == 2:
                   x += self.step_size * torch.matmul(inv_mass, p.view(-1,1)).view(-1)
               else:
                   x += self.step_size * inv_mass * p
    
           # Momentum update: a full step on all but the final iteration, where only a
           # half-step is needed; see https://arxiv.org/pdf/1206.1901.pdf, p. 14.
           if n < self.L - 1:
               p += self.step_size * self.grad_of_sample(x)
           else:
               p += 0.5 * self.step_size * self.grad_of_sample(x)
    
           # Store the updated position and momentum vector
           _new_x.append(x.clone().detach())
           _new_p.append(p.clone().detach())
    
       return torch.stack(_new_x), torch.stack(_new_p)

    And then later on you do

    def Leapfrog(self, x, p, *args, **kwargs):

       if self.sampler == Sampler.HMC and self.integrator in [Integrator.EXPLICIT, Integrator.IMPLICIT, Integrator.S3]:
           return self._Leapfrog_HMC(x, p, inv_mass=self.inv_mass)

    Note here how self.L and self.step_size are members of the sampler class, not passed to the function as args.
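    To make the "separate functions for each step" idea concrete, the position and momentum updates could themselves be pulled out into small helpers, e.g. (again just a sketch, the names are illustrative):

    def _momentum_step(self, x: torch.Tensor, p: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        """ Momentum update of size ``scale * step_size`` in the direction of the gradient. """
        return p + scale * self.step_size * self.grad_of_sample(x)

    def _position_step(self, x: torch.Tensor, p: torch.Tensor,
                       inv_mass: Optional[Union[torch.Tensor, Sequence]] = None) -> torch.Tensor:
        """ Position update, optionally preconditioned with the inverse mass matrix. """
        if inv_mass is None:
            return x + self.step_size * p
        if isinstance(inv_mass, Sequence):
            return x + self.step_size * _multiply_with_block_matrix(inv_mass, p)
        if inv_mass.ndim == 2:
            return x + self.step_size * torch.matmul(inv_mass, p.view(-1, 1)).view(-1)
        return x + self.step_size * inv_mass * p

    The loop body of _Leapfrog_HMC then reduces to a couple of calls to these helpers.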
  • why don't you make multiplication with a block matrix a separate function, since that code snippet comes up a few times (whenever you're multiplying the momentum with the inverse mass, basically):
    import torch
    from typing import Sequence

    def _multiply_with_block_matrix(A: Sequence, x: torch.Tensor) -> torch.Tensor:
       """ Multiplies a vector ``x`` with a block-diagonal matrix ``A``, given as a Sequence of square blocks """
       _i = 0
       _new_x = torch.zeros_like(x)
       for block in A:
           _j = block.shape[0]
           _new_x[_i:_i + _j] = torch.matmul(block, x[_i:_i + _j].view(-1, 1)).view(-1)
           _i += _j
       return _new_x
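    Just to illustrate the expected input format:

    # Block-diagonal inverse mass with a 2x2 and a 3x3 block, applied to a length-5 momentum vector
    inv_mass = [0.5 * torch.eye(2), 2.0 * torch.eye(3)]
    p = torch.arange(5, dtype=torch.float32)
    _multiply_with_block_matrix(inv_mass, p)  # -> [0.0, 0.5, 4.0, 6.0, 8.0]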
    
  • the utils.has_nan_or_inf function can be simplified:
    def _has_nan_or_inf(val: torch.Tensor) -> bool:
       # isfinite is False for nan, +inf and -inf, so one reduction covers all three cases
       return not bool(torch.isfinite(val).all())
    
  • In general I would encourage you to write object-oriented code, thinking about which constants and variables can be class members. Almost all functions are being passed args like jitter, sampler, integrator, metric, softabs_const, etc. This makes the signatures long, is repetitive and unnecessary, and just makes the code harder to read; for instance, in samplers.py::sample, you have
     momentum = gibbs(params, sampler=sampler, log_prob_func=log_prob_func, jitter=jitter,
                      normalizing_const=normalizing_const, softabs_const=softabs_const, metric=metric,
                      mass=mass)
    
     ham = hamiltonian(params, momentum, log_prob_func, jitter=jitter, softabs_const=softabs_const,
                       explicit_binding_const=explicit_binding_const, normalizing_const=normalizing_const,
                       sampler=sampler, integrator=integrator, metric=metric, inv_mass=inv_mass)
    
     leapfrog_params, leapfrog_momenta = leapfrog(params, momentum, log_prob_func, sampler=sampler,
                                                  integrator=integrator, steps=num_steps_per_sample,
                                                  step_size=step_size, inv_mass=inv_mass, jitter=jitter,
                                                  jitter_max_tries=jitter_max_tries,
                                                  fixed_point_threshold=fixed_point_threshold,
                                                  fixed_point_max_iterations=fixed_point_max_iterations,
                                                  normalizing_const=normalizing_const,
                                                  softabs_const=softabs_const,
                                                  explicit_binding_const=explicit_binding_const,
                                                  metric=metric, store_on_GPU=store_on_GPU, debug=debug,
                                                  pass_grad=pass_grad)
    
    which could be condensed down to
     momentum = gibbs(params)
     ham = hamiltonian(params, momentum)
     leapfrog_params, leapfrog_momenta = leapfrog(params, momentum)
    
    and all other args and kwargs are just fixed class members (e.g. the mass or inv_mass don't need to be passed to functions all the time). This immediately makes the code much shorter and much cleaner.
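    A sketch of how this could look (purely illustrative, not the current hamiltorch API): all the configuration is stored once in __init__, and the per-step routines read it from self:

    class HMCSampler:
       def __init__(self, log_prob_func, *, step_size=0.1, num_steps_per_sample=10,
                    inv_mass=None, jitter=0.01, softabs_const=1e6, normalizing_const=1.0):
           self.log_prob_func = log_prob_func
           self.step_size = step_size
           self.num_steps_per_sample = num_steps_per_sample
           self.inv_mass = inv_mass
           self.jitter = jitter
           self.softabs_const = softabs_const
           self.normalizing_const = normalizing_const

       def gibbs(self, params):
           # draws a fresh momentum; everything it needs (mass, jitter, ...) lives on self
           ...

       def hamiltonian(self, params, momentum):
           ...

       def leapfrog(self, params, momentum):
           ...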
  • You are repeatedly calculating the log probability log_prob when it could in fact be stored as a class member and shared and updated across functions, thereby cutting down the compute time.
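    For example (sketch only), the current log-probability could be cached and recomputed only when the position actually changes:

    # in __init__: self._cached_x, self._cached_log_prob = None, None

    def _log_prob(self, x: torch.Tensor) -> torch.Tensor:
       """ Cached log-probability: only recomputed when the position has changed. """
       if self._cached_x is None or not torch.equal(x, self._cached_x):
           self._cached_x = x.clone().detach()
           self._cached_log_prob = self.log_prob_func(x)
       return self._cached_log_prob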
  • In Python, you can declare variable types and return types in the function signature:
    def func(a: torch.Tensor, b: Union[float, torch.Tensor]) -> torch.Tensor:
       pass
    
    I would encourage you to do this; it makes it much easier to understand and debug what each function expects and what it returns. It would also help you cut down the docstrings a little.
  • I think the sample function should not take num_samples as an argument: rather, it should do a single step, so that it can be called in a loop:
    for n in range(num_samples):
       HMC.sample(*args, **kwargs)
       # Can do other things here ...
    
    Here again, the advantage of using a class is obvious: the sample function should update the class-internal state, such as the current position and momentum, and perhaps return the sample, rather than just outputting a really long list of samples at the end.
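    A rough sketch of what such a single-step sample method could look like (again purely illustrative, and depending on how the methods and state above are defined):

    def sample(self) -> torch.Tensor:
       """ Performs one HMC step, updates the internal state and returns the current sample. """
       # self.params holds the current position
       momentum = self.gibbs(self.params)
       ham = self.hamiltonian(self.params, momentum)
       new_params, new_momentum = self.leapfrog(self.params, momentum)
       new_ham = self.hamiltonian(new_params, new_momentum)

       # Metropolis acceptance: accept with probability min(1, exp(H_old - H_new))
       if torch.rand(1).item() < float(torch.exp(ham - new_ham)):
           self.params = new_params

       return self.params.clone().detach()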
  • The code is not PEP 8-compliant: you don't need to fix that manually; you could run a code formatter like black to do it for you.

I might update this list later on as I keep working with it :) But thanks again for your great work!

ThGaskin avatar Oct 31 '23 15:10 ThGaskin

Thanks very much for your comments! Please feel free to make suggested pull requests and I can check if they make sense and then merge them. If I get time, I can try and make some updates, but this is definitely an open-source project, so definitely don't hold back!

I appreciate the time you have spent looking into this code.

Best,

Adam

AdamCobb avatar Nov 14 '23 15:11 AdamCobb