
dev: Transformations.

Open daniel-dodd opened this issue 3 years ago • 1 comment

Perhaps we should not require parameter transformations (via transform) in objective functions; transformation should live with model training instead.

If I have an ELBO or the marginal log-likelihood, shouldn't I just be able to pass my parameters to it without transforming anything?

For example, for the marginal log-likelihood of GP regression, the fit abstraction currently constructs the objective as objective = posterior.mll(D, transformation, negative=True), and abstractions.py defines a loss function (that stops gradients on non-trainable parameters) to train the parameters:

def loss(params):
    params = trainable_params(params, trainables)
    return objective(params)

Perhaps it would be nicer to define the objective with objective = posterior.mll(D, negative=True) (i.e. no transformations specified) and apply the transformation in the training loop instead, e.g.

def loss(params):
    params = trainable_params(params, trainables)
    params = transform(params, transformation)
    return objective(params)

The training loop could even take a bijector argument, with abstractions.py managing the forward and reverse transformations (gpx.initialise could return a dictionary of bijectors instead of the current constrainer/unconstrainer convention).
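To make the proposal concrete, here is a minimal sketch of what a bijector dictionary could look like. The Bijector class, the constrain/unconstrain helpers, and the parameter names are all hypothetical illustrations, not GPJax API; the point is only that pairing a forward map with its inverse in one object replaces the separate constrainer and unconstrainer dictionaries:

```python
import math

# Hypothetical bijector: a forward map (unconstrained -> constrained)
# bundled with its inverse, replacing separate constrainer/unconstrainer dicts.
class Bijector:
    def __init__(self, forward, inverse):
        self.forward = forward
        self.inverse = inverse

# Softplus maps the real line onto the positive reals.
Softplus = Bijector(
    forward=lambda x: math.log(1.0 + math.exp(x)),
    inverse=lambda y: math.log(math.exp(y) - 1.0),
)
Identity = Bijector(forward=lambda x: x, inverse=lambda x: x)

# One bijector per parameter, as gpx.initialise might return.
bijectors = {"lengthscale": Softplus, "variance": Softplus, "mean": Identity}

def constrain(params, bijectors):
    """Map unconstrained parameters to their constrained space."""
    return {k: bijectors[k].forward(v) for k, v in params.items()}

def unconstrain(params, bijectors):
    """Map constrained parameters back to the unconstrained space."""
    return {k: bijectors[k].inverse(v) for k, v in params.items()}

unconstrained = {"lengthscale": 0.0, "variance": -1.0, "mean": 0.3}
constrained = constrain(unconstrained, bijectors)
roundtrip = unconstrain(constrained, bijectors)
```

With this in place, the training loop would only need `params = constrain(params, bijectors)` before calling the objective, and the objective itself stays transformation-free.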

daniel-dodd avatar Aug 25 '22 15:08 daniel-dodd

On a related note, it seems that the log-det-Jacobian term is missing from the objective function that is passed to HMC. (At least I could not see the string 'jacfwd' or 'jacrev' anywhere in the codebase :).) The importance of this term is illustrated in https://github.com/probml/pyprobml/blob/master/notebooks/book2/03/change_of_variable_hmc.ipynb
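For readers unfamiliar with the issue: when HMC samples in an unconstrained space u while the model is defined over constrained parameters theta = f(u), the unconstrained log density must include the change-of-variables correction log |f'(u)|, otherwise the sampler targets the wrong distribution. A minimal sketch, using a softplus transform and a toy exponential target (both chosen here for illustration, not taken from GPJax):

```python
import math

def softplus(u):
    """theta = f(u): maps the real line to the positive reals."""
    return math.log(1.0 + math.exp(u))

def log_det_jacobian(u):
    """log |f'(u)| for softplus; f'(u) = sigmoid(u)."""
    return -math.log(1.0 + math.exp(-u))

def log_target(theta):
    """Toy constrained target: log p(theta) = -theta, an Exp(1) density on R+."""
    return -theta

def log_density_unconstrained(u):
    """Correct unconstrained log density for HMC: log p(f(u)) + log |f'(u)|."""
    return log_target(softplus(u)) + log_det_jacobian(u)
```

Dropping the log_det_jacobian term here would make HMC sample from a density proportional to exp(-softplus(u)) rather than the pushforward of Exp(1), which is exactly the bias the linked notebook demonstrates.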

murphyk avatar Sep 09 '22 22:09 murphyk

Thanks for spotting this @murphyk (and apologies for the slow reply)! We recently refactored GPJax: v0.5.0 removed transformations from all objects, and our most recent release, v0.5.2, removes priors from the "marginal likelihood", leaving it as a probability transition kernel only (#153). Prior transformations are now left to the user and can, for example, be handled via transformed.dist as demonstrated in the tfp integration notebook example.

daniel-dodd avatar Nov 30 '22 15:11 daniel-dodd