Visualisation with a trained model
Hi, great package!
I am looking at the examples in the plotting_examples folder. These seem to work independently of a trained torch model? What would be a minimal way / example to use them with a trained model, e.g. how to visualise the latent traversals of a trained model?
Best regards
Hi @dgm2, thank you!
My apologies for the delayed response. You are right, most of them are independent.
However, you can have a look at the helper code inside the pytorch lightning utilities, one of the callback classes is specifically for generating latent traversals:
https://github.com/nmichlo/disent/blob/ff462ba567a734041874cc584b97695a81729498/disent/util/lightning/callbacks/_callback_vis_latents.py#L197-L261
This method uses helper functions from `disent.util.visualize.vis_img` to convert tensors to images, `disent.util.visualize.vis_latents` to generate latent sequences, and `disent.util.visualize.vis_util` to combine the images together into a grid or into sequential frames.
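Not disent's actual implementation, but as a rough sketch of what the grid-combining step does, here is a minimal numpy version (the function name and `pad` parameter are my own, for illustration only):

```python
import numpy as np

def make_image_grid(stills: np.ndarray, pad: int = 2) -> np.ndarray:
    """Tile an array of shape (rows, cols, H, W, C) into one (rows*H, cols*W, C) image."""
    rows, cols, H, W, C = stills.shape
    # blank canvas with `pad` pixels of spacing between cells
    grid = np.zeros((rows * (H + pad) - pad, cols * (W + pad) - pad, C), dtype=stills.dtype)
    for r in range(rows):
        for c in range(cols):
            grid[r * (H + pad):r * (H + pad) + H,
                 c * (W + pad):c * (W + pad) + W] = stills[r, c]
    return grid

# e.g. 6 latent dims, 4 frames of 64x64 RGB stills
stills = np.random.rand(6, 4, 64, 64, 3)
grid = make_image_grid(stills)
print(grid.shape)  # (394, 262, 3): 6*(64+2)-2 rows, 4*(64+2)-2 cols
```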
The code is more complicated than it needs to be for most cases because of some additional handling and quirks. Maybe we can add a dedicated docs example for latent traversals.
Thanks!
The callback method returns `stills`, `frames`, and `image`.
How should I input these into `plot_dataset_traversals`, or into `visualize_dataset_traversal`?
What are the corresponding values there? E.g. does `stills` correspond to `grid` as input to `plt_subplots_imshow`?
Does this example make sense? Many thanks!
```python
trainer = pl.Trainer(
    max_steps=2048,
    gpus=1 if torch.cuda.is_available() else None,
    logger=False,
    checkpoint_callback=False,
    max_epochs=1,
)
trainer.fit(module, dataloader)
# trainer.save_checkpoint("trained.ckpt")

viz = VaeLatentCycleLoggingCallback()
stills, frames_, image_ = viz.generate_visualisations(
    trainer_or_dataset=trainer,
    pl_module=trainer.lightning_module,
    num_frames=4,
    num_stats_samples=15,
)

plt_scale = 4.5
offset = 0.75
factors, frames, _, _, c = stills.shape
plt_subplots_imshow(
    grid=stills,
    title=None,
    row_labels=None,
    subplot_padding=None,
    figsize=(offset + (1 / 2.54) * frames * plt_scale,
             (1 / 2.54) * (factors + 0.45) * plt_scale),
    show=False,
)
```
Your example makes sense, but admittedly it has been a while since I last touched the code (I realize the current system is not ideal for these custom scripts, so this will need to be fixed in the future).
- `stills` should be an array of shape `(num_latents, num_frames, 64, 64, 3)` containing the individual latent traversals.
- `frames` is a concatenated version of `stills` intended for creating videos: the individual stills are combined over the factors dimension into an image grid per frame. The resulting array is approximately of shape `(num_frames, ~(64 * grid_h), ~(64 * grid_w), 3)`.
- `image` is a single plottable image with all the latent traversals merged together into one grid; the x axis of this grid corresponds to `num_latents` and the y axis to `num_frames` (or vice versa), so the shape is approximately `(~(64 * num_latents), ~(64 * num_frames), 3)`.
You can try plotting the image directly with `plt.imshow(image)`, or create your own visualization/animation from `frames` or `stills`.
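For example, a minimal plain-matplotlib sketch (using random placeholder data in place of the callback's real output) that tiles `stills` into a figure, one row per latent and one column per frame:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, so this also works in scripts
import matplotlib.pyplot as plt

# stills: (num_latents, num_frames, H, W, C) -- placeholder data here,
# in practice this would come from the callback's generate_visualisations
stills = np.random.rand(6, 4, 64, 64, 3)
factors, frames = stills.shape[:2]

fig, axes = plt.subplots(factors, frames, figsize=(frames * 1.5, factors * 1.5))
for r in range(factors):
    for c in range(frames):
        axes[r, c].imshow(stills[r, c])
        axes[r, c].axis("off")
fig.tight_layout()
fig.savefig("traversals.png")
```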