Great package -- but code quality could be improved
Dear all,
Thank you very much for this repo and for implementing this -- I've been enjoying using it. It's great to see well-documented and tested (!) code being written for scientific research.
However, to enable users to adapt this to their own purposes, I had some ideas for how the underlying code could be improved and cleaned up a little. Here are some things I noticed while working with it:
- At the moment, too many functions are passed parameters that should be fixed class members, such as the `metric` or the `sampler`: these do not change between setups and do not need to be passed as args. Such (static) parameters should also not be checked inside functions; e.g. in `samplers.py`, line 123 (the `fisher` function), you're raising

```python
raise ValueError('Unknown metric: {}'.format(metric))
```

This shouldn't happen here: `metric` should be a class member that is checked upon instantiation of the class, and never again thereafter.
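To illustrate the pattern I mean (the class name and the accepted metric names here are just placeholders, not taken from the repo):

```python
class HMCSampler:
    """Hypothetical sampler class, used only to illustrate the pattern."""

    _KNOWN_METRICS = ('HESSIAN', 'SOFTABS', 'JACOBIAN_DIAG')

    def __init__(self, metric: str):
        # Validate the static parameter exactly once, at construction time ...
        if metric not in self._KNOWN_METRICS:
            raise ValueError('Unknown metric: {}'.format(metric))
        # ... then store it as a class member, so no function needs it as an arg
        self.metric = metric
```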
- the `leapfrog` function is too long and should be broken down into smaller sub-functions, which are more legible and easier to understand. For instance, just have separate functions for each step, which are then called in the main routine:
```python
def _Leapfrog_HMC(self, x: torch.Tensor, p: torch.Tensor, *,
                  inv_mass: Union[torch.Tensor, Sequence] = None
                  ) -> tuple[torch.Tensor, torch.Tensor]:
    """Leapfrog integrator. Integrates Hamilton's equations for L steps.

    :param x: current sample position
    :param p: current momentum
    :param inv_mass: (optional) inverse mass matrix to be used
    :return: updated positions and momenta
    """
    x = x.clone()
    p = p.clone()
    _new_x, _new_p = [], []

    # Update the momentum in the direction of the gradient
    p += 0.5 * self.step_size * self.grad_of_sample(x)

    # Update the position for L steps, using the mass matrix, if given
    for n in range(self.L):
        # If no inverse mass given, update the samples using Hamiltonian dynamics
        if inv_mass is None:
            x += self.step_size * p
        elif isinstance(inv_mass, Sequence):
            x += self.step_size * _multiply_with_block_matrix(inv_mass, p)
        elif len(inv_mass.shape) == 2:
            x += self.step_size * torch.matmul(inv_mass, p.view(-1, 1)).view(-1)
        else:
            x += self.step_size * inv_mass * p

        # Update the momentum: only a half step on the last iteration, which is
        # all the Hamiltonian check needs;
        # see https://arxiv.org/pdf/1206.1901.pdf, p. 14.
        if n < self.L - 1:
            p += self.step_size * self.grad_of_sample(x)
        else:
            p += 0.5 * self.step_size * self.grad_of_sample(x)

        # Store the updated position and momentum vector
        _new_x.append(x.clone().detach())
        _new_p.append(p.clone().detach())

    return torch.stack(_new_x), torch.stack(_new_p)
```

And then later on you do

```python
def Leapfrog(x, p, *args, **kwargs):
    if self.sampler == Sampler.HMC and self.integrator in [Integrator.EXPLICIT,
                                                           Integrator.IMPLICIT,
                                                           Integrator.S3]:
        return self._Leapfrog_HMC(x, p, inv_mass=inv_mass)
```

Note here how `self.L` and `self.step_size` are members of the sampler class, not passed to the function as args.

- why don't you make multiplication with a block matrix a separate function, since that code snippet comes up a few times (whenever you're multiplying the momentum with the inverse mass, basically):
```python
from typing import Sequence

def _multiply_with_block_matrix(A: Sequence, x: torch.Tensor) -> torch.Tensor:
    """Multiplies a vector ``x`` with a block matrix ``A``, given as a Sequence of blocks."""
    _i = 0
    _new_x = torch.zeros_like(x)
    for block in A:
        _j = block[0].shape[0]
        _new_x[_i:_i + _j] = torch.matmul(block, x[_i:_i + _j].view(-1, 1)).view(-1)
        _i += _j
    return _new_x
```
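For example, with a block-diagonal inverse mass matrix given as a list of (made-up) blocks:

```python
import torch

# Two diagonal blocks of a 5x5 block-diagonal inverse mass matrix (made-up values)
inv_mass = [0.5 * torch.eye(2), 2.0 * torch.eye(3)]
p = torch.arange(5, dtype=torch.float32)

# Equivalent to assembling the full 5x5 matrix and calling torch.matmul once,
# but without ever materialising the off-diagonal zeros
p_new = _multiply_with_block_matrix(inv_mass, p)
```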
- the `utils.has_nan_or_inf` function can be simplified:

```python
def _has_nan_or_inf(val: torch.Tensor) -> bool:
    # torch.isinf already covers negative infinity, and .any() collapses the
    # element-wise masks to a single value
    return bool(torch.isnan(val).any() or torch.isinf(val).any())
```

- In general I would encourage you to write object-oriented code, thinking about which constants and variables can be class members. Almost all functions are being passed args like
`jitter`, `sampler`, `integrator`, `metric`, `softabs_const`, etc. This makes the signatures long, repetitive, and unnecessary, and just makes the code harder to read; for instance, in `samplers.py::sample`, you have
```python
momentum = gibbs(params, sampler=sampler, log_prob_func=log_prob_func, jitter=jitter,
                 normalizing_const=normalizing_const, softabs_const=softabs_const,
                 metric=metric, mass=mass)
ham = hamiltonian(params, momentum, log_prob_func, jitter=jitter,
                  softabs_const=softabs_const, explicit_binding_const=explicit_binding_const,
                  normalizing_const=normalizing_const, sampler=sampler,
                  integrator=integrator, metric=metric, inv_mass=inv_mass)
leapfrog_params, leapfrog_momenta = leapfrog(params, momentum, log_prob_func,
                  sampler=sampler, integrator=integrator, steps=num_steps_per_sample,
                  step_size=step_size, inv_mass=inv_mass, jitter=jitter,
                  jitter_max_tries=jitter_max_tries,
                  fixed_point_threshold=fixed_point_threshold,
                  fixed_point_max_iterations=fixed_point_max_iterations,
                  normalizing_const=normalizing_const, softabs_const=softabs_const,
                  explicit_binding_const=explicit_binding_const, metric=metric,
                  store_on_GPU=store_on_GPU, debug=debug, pass_grad=pass_grad)
```

which could be condensed down to

```python
momentum = gibbs(momentum)
ham = hamiltonian(params)
leapfrog_params, leapfrog_momenta = leapfrog(params, momentum)
```

and all other args and kwargs are just fixed class members (e.g. the `mass` or `inv_mass` don't need to be passed to functions all the time). This immediately makes the code much shorter and much cleaner.

- You are repeatedly calculating the log probability `log_prob`, when this could in fact be stored as a class member and shared and updated across functions, thereby cutting down the compute time.
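What I have in mind is something like this (class, method, and attribute names here are hypothetical, just to sketch the caching pattern):

```python
import torch

class HMCSampler:
    def __init__(self, log_prob_func):
        self.log_prob_func = log_prob_func
        self._log_prob = None  # cached log probability, shared across methods

    def log_prob(self, params: torch.Tensor) -> torch.Tensor:
        # Only recompute when the cache has been invalidated
        if self._log_prob is None:
            self._log_prob = self.log_prob_func(params)
        return self._log_prob

    def update_position(self, new_params: torch.Tensor) -> None:
        # The position changed, so the cached log probability is stale
        self.params = new_params
        self._log_prob = None
```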
- In Python, you can declare variable types and return types in the function signature:
```python
def func(a: torch.Tensor, b: Union[float, torch.Tensor]) -> torch.Tensor:
    pass
```

I would encourage you to do this: it makes it much easier to understand and debug what each function wants and what it returns. This would also help you cut down the docstrings a little.
- I think the `sample` function should not take `num_samples` as an argument: rather, it should do a single step, so that it can be called in a loop:
```python
for n in range(num_samples):
    HMC.sample(*args, **kwargs)
    # Can do other things here ...
```

Here again, the advantage of using a class is obvious: the `sample` function should update the class-internal parameters, like the current position and momentum etc., and perhaps return the sample, rather than just outputting a really long list of samples at the end.
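Something along these lines, where the helper methods `gibbs`, `leapfrog`, and `accept` are hypothetical stand-ins for the existing routines turned into class methods:

```python
def sample(self) -> torch.Tensor:
    """Hypothetical single HMC step: propose, accept/reject, update state."""
    momentum = self.gibbs()  # resample the momentum
    new_params, new_momentum = self.leapfrog(self.params, momentum)
    # Metropolis acceptance based on the current and proposed Hamiltonians
    if self.accept(self.params, momentum, new_params, new_momentum):
        self.params = new_params  # update the internal state
    return self.params.clone().detach()
```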
- The code is not PEP 8-compliant: you don't need to pepify it manually, you could run a code formatter like `black` to do it for you.

I might update this list later on as I keep working with it :) But thanks again for your great work!
Thanks very much for your comments! Please feel free to open pull requests with your suggestions and I can check whether they make sense and then merge them. If I get time, I can try to make some of these updates myself, but this is definitely an open-source project, so don't hold back!
I appreciate the time you have spent looking into this code.
Best,
Adam