pyro

Spline inverse

[Open] guyko81 opened this issue 3 years ago • 3 comments

Hi,

I can't find the proper way to invert a Spline transform. What I would like to do is a two-step approach:

  1. Train a Spline transform on the data, starting from a 1D standard Normal base distribution:

```python
import torch
import pyro.distributions as dist
from pyro.distributions.transforms import Spline
base_dist = dist.Normal(torch.zeros(1), torch.ones(1))
spline_transform = Spline(1, count_bins=4, bound=3.)
flow_dist = dist.TransformedDistribution(base_dist, [spline_transform])
# Y_train is my 1D NumPy array of targets
loss = -flow_dist.log_prob(torch.tensor(Y_train.reshape(-1, 1), dtype=torch.float)).mean()
```
  2. Get the "more Normal" version of the target variable. This is where I'm stuck: I don't know how to invert the Spline transform and run Y_train through it to get back the "normalized" version. I would expect something like the line below, but please help; I couldn't find anything like it in the documentation.

```python
normalized_Y_train = spline_transform.inv(torch.tensor(Y_train.reshape(-1, 1), dtype=torch.float))
```

Apparently `parameters()` is not implemented for `spline_transform.inv`. The following loop throws `AttributeError: '_InverseTransform' object has no attribute 'parameters'`:

```python
for param in spline_transform.inv.parameters():
    if param.requires_grad:
        print(param.data)
```

The same loop without `.inv` prints the parameters as expected. What am I missing?

guyko81 · Jun 04 '21 09:06

cc @stefanwebb

fritzo · Jun 04 '21 13:06

@guyko81 the `.inv` property swaps the forward and inverse operations and returns a new transform, but that transform uses the same parameters as the original. Have you tried calling `transform._inverse(Y_train)` on the original transform?

stefanwebb · Jun 07 '21 02:06

@stefanwebb No, I wasn't aware of that method, thanks for sharing! I'll try it and report back if it doesn't work.

guyko81 · Jun 07 '21 13:06