pyro
Spline inverse
Hi,
I can't find the proper way to invert a Spline transform. What I would like to do is a two-step approach:
- train a Spline transform on top of a 1D standard Normal base distribution so that it fits the data:
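import torch
import pyro.distributions as dist
from pyro.distributions.transforms import Spline
base_dist = dist.Normal(torch.zeros(1), torch.ones(1))
spline_transform = Spline(1, count_bins=4, bound=3.)
flow_dist = dist.TransformedDistribution(base_dist, [spline_transform])
# Y_train is a NumPy array of 1D targets
loss = -flow_dist.log_prob(torch.tensor(Y_train.reshape(-1, 1), dtype=torch.float)).mean()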
- get the "more Normal" version of the target variable. This is where I'm stuck: I don't know how to invert the Spline transform and pass Y_train through it to get back the "normalized" version. I would expect something like the following, but I couldn't find anything like it in the documentation:
normalized_Y_train = spline_transform.inv(torch.tensor(Y_train.reshape(-1, 1), dtype=torch.float))
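To illustrate the mechanics of step 2, here is a minimal, self-contained sketch in plain Python (no Pyro or PyTorch) of a monotone spline with a forward and an inverse pass. It is a hypothetical toy: PiecewiseLinearSpline, its widths/heights arguments and the example knots are made up, and it uses simple piecewise-linear bins rather than the linear rational splines that Pyro's Spline implements. The idea it shows: the flow maps base samples to data, so running data through the inverse maps it back towards the base (standard Normal) space, and the inverse reuses the same knots by swapping the roles of the x and y grids.

```python
import bisect


class PiecewiseLinearSpline:
    """Hypothetical toy: monotone piecewise-linear map on [-bound, bound], identity outside."""

    def __init__(self, widths, heights, bound=3.0):
        if len(widths) != len(heights) or min(widths) <= 0 or min(heights) <= 0:
            raise ValueError("need equal numbers of strictly positive widths and heights")
        self.bound = bound
        self.xk = self._knots(widths)   # knot positions on the input (base) axis
        self.yk = self._knots(heights)  # knot positions on the output (data) axis

    def _knots(self, sizes):
        # Rescale bin sizes to span [-bound, bound] and accumulate them into knots.
        total = sum(sizes)
        knots = [-self.bound]
        for s in sizes:
            knots.append(knots[-1] + 2 * self.bound * s / total)
        knots[-1] = self.bound  # guard against floating-point drift
        return knots

    @staticmethod
    def _interp(v, src, dst):
        # Map v from knot grid src to knot grid dst by linear interpolation.
        if v <= src[0] or v >= src[-1]:
            return v  # identity tails outside the bounded interval
        i = bisect.bisect_right(src, v) - 1  # bin containing v
        t = (v - src[i]) / (src[i + 1] - src[i])
        return dst[i] + t * (dst[i + 1] - dst[i])

    def forward(self, x):
        # Base space -> data space.
        return self._interp(x, self.xk, self.yk)

    def inverse(self, y):
        # Data space -> base space: same lookup with the two grids swapped.
        return self._interp(y, self.yk, self.xk)


spline = PiecewiseLinearSpline(widths=[1, 1, 1, 1], heights=[1, 3, 3, 1], bound=3.0)
y_train = [-5.0, -2.0, 0.2, 1.0, 2.9, 4.0]  # hypothetical targets
normalized = [spline.inverse(y) for y in y_train]
recovered = [spline.forward(z) for z in normalized]
print(normalized)
print(max(abs(a - b) for a, b in zip(recovered, y_train)))  # round-trip error
```

Because the map is strictly increasing inside the bounds, each y falls into exactly one bin, so the inverse can be computed exactly with a bin search instead of a numerical root-finder.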
Apparently parameters() is not implemented for spline_transform.inv: the following raises AttributeError: '_InverseTransform' object has no attribute 'parameters'
for param in spline_transform.inv.parameters():
    if param.requires_grad:
        print(param.data)
The same loop without .inv prints the parameters as expected. What am I missing?
cc @stefanwebb
@guyko81 the .inv method swaps the forward and inverse operations and returns a new transform, but it keeps the same parameters, so you can read them from the original transform's parameters(). Have you tried calling transform._inverse(Y_train) on the original transform?
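The point about shared parameters can be sketched with a toy inverse wrapper. Affine and InverseView below are hypothetical stand-ins in plain Python, loosely modeled on the way torch.distributions keeps a reference to the original transform inside the object returned by .inv; they are not Pyro's or PyTorch's actual classes. The wrapper stores the original rather than a copy, so calling it runs the original's _inverse, it has no parameters() method of its own, and any update to the original (for example, an optimizer step) is immediately visible through it.

```python
class Affine:
    """Hypothetical toy transform y = scale * x + shift, standing in for a trainable Spline."""

    def __init__(self, scale, shift):
        self.scale = scale
        self.shift = shift

    def parameters(self):
        # Only the original transform owns and exposes the parameters.
        return [self.scale, self.shift]

    def __call__(self, x):
        return self.scale * x + self.shift

    def _inverse(self, y):
        return (y - self.shift) / self.scale

    @property
    def inv(self):
        # Returns a wrapper that references self; nothing is copied.
        return InverseView(self)


class InverseView:
    """Hypothetical inverse wrapper: forward and inverse are swapped, parameters are shared."""

    def __init__(self, base):
        self._inv = base

    def __call__(self, y):
        # The wrapper's forward pass is the original's inverse pass.
        return self._inv._inverse(y)

    def _inverse(self, x):
        return self._inv(x)

    @property
    def inv(self):
        # Inverting twice gives back the original transform.
        return self._inv


t = Affine(scale=2.0, shift=1.0)
inv = t.inv
print(t._inverse(5.0), inv(5.0))   # both compute (5 - 1) / 2
print(hasattr(inv, "parameters"))  # the wrapper has no parameters() method
print(t.parameters())              # read the shared parameters from the original
t.scale = 4.0                      # simulate a training update on the original
print(inv(5.0))                    # the wrapper sees the updated scale
```

Keeping the inverse as a reference instead of a separate object with its own copy of the parameters is what guarantees the two directions can never drift apart during training.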
@stefanwebb No, I wasn't aware of that method, thanks for sharing! I'll try it and report back if it doesn't work.