nflows
If my data is 6-dimensional, how can I use this code to process it?
I want to use this code to predict stocks, but my data has shape (n, 6), while the moons example has only 2 dimensions, so I can't do this:
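```python
import torch
import matplotlib.pyplot as plt

# `flow` (the trained 2-D model) and `i` (the iteration counter) come from the moons training loop
xline = torch.linspace(-1.5, 2.5, 100)
yline = torch.linspace(-.75, 1.25, 100)
xgrid, ygrid = torch.meshgrid(xline, yline, indexing='ij')
# One column each for x and y: a (10000, 2) batch of 2-D points
xyinput = torch.cat([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1)

with torch.no_grad():
    zgrid = flow.log_prob(xyinput).exp().reshape(100, 100)

plt.contourf(xgrid.numpy(), ygrid.numpy(), zgrid.numpy())
plt.title('iteration {}'.format(i + 1))
plt.show()
```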
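The grid code above works for moons only because each grid point is a 2-D input: the batch passed to `flow.log_prob` has shape (10000, 2). A 6-D flow (assuming it is built as in the moons example but with `StandardNormal(shape=[6])` and `features=6` in each transform) expects (N, 6) inputs, and a full 6-D grid cannot be drawn as one contour plot. One possible workaround, sketched below under those assumptions, is to plot a 2-D slice: vary two chosen features over the grid and hold the other four fixed at an anchor point (for example, the per-feature mean of your data). The function `density_on_slice` and its parameters are hypothetical names for illustration, and a standard normal from `torch.distributions` stands in for the trained flow.

```python
import torch
from typing import Callable, Sequence, Tuple


def density_on_slice(
    log_prob: Callable[[torch.Tensor], torch.Tensor],
    anchor: torch.Tensor,
    dims: Tuple[int, int] = (0, 1),
    xlim: Sequence[float] = (-1.5, 2.5),
    ylim: Sequence[float] = (-0.75, 1.25),
    steps: int = 100,
):
    """Hypothetical helper: evaluate a D-dimensional density on a 2-D slice.

    Features `dims[0]` and `dims[1]` vary over a steps x steps grid; the
    remaining D - 2 features are held fixed at the values in `anchor`.
    """
    xline = torch.linspace(xlim[0], xlim[1], steps)
    yline = torch.linspace(ylim[0], ylim[1], steps)
    xgrid, ygrid = torch.meshgrid(xline, yline, indexing="ij")

    # Start every grid point at the anchor (repeat copies the data),
    # then overwrite the two varied features with the grid coordinates.
    inputs = anchor.reshape(1, -1).repeat(steps * steps, 1)
    inputs[:, dims[0]] = xgrid.reshape(-1)
    inputs[:, dims[1]] = ygrid.reshape(-1)

    with torch.no_grad():
        zgrid = log_prob(inputs).exp().reshape(steps, steps)
    return xgrid, ygrid, zgrid


# Usage with a stand-in 6-D density (with a trained model, pass flow.log_prob instead)
mvn = torch.distributions.MultivariateNormal(torch.zeros(6), torch.eye(6))
xgrid, ygrid, zgrid = density_on_slice(mvn.log_prob, anchor=torch.zeros(6))
print(zgrid.shape)  # torch.Size([100, 100])
```

The result can be plotted with `plt.contourf(xgrid.numpy(), ygrid.numpy(), zgrid.numpy())` exactly as in the moons example. Note that a slice is proportional to the conditional density of the two varied features at the anchor, not their marginal; for marginals, scatter-plotting pairs of features from `flow.sample(...)` is an alternative.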