
visualisation with trained model

Open dgm2 opened this issue 3 years ago • 3 comments

Hi, great package!

I am looking at the examples in the plotting_examples folder. These seem to work independently of a trained torch model? What would be a minimal way / example to use them with a trained model, e.g. how to visualise the latent traversals of a trained model?

Best regards

dgm2 avatar Jan 02 '23 10:01 dgm2

Hi @dgm2, thank you!

My apologies for the delayed response. You are right: most of them are independent of a trained model.

However, you can have a look at the helper code inside the PyTorch Lightning utilities; one of the callback classes is specifically for generating latent traversals:

https://github.com/nmichlo/disent/blob/ff462ba567a734041874cc584b97695a81729498/disent/util/lightning/callbacks/_callback_vis_latents.py#L197-L261

This method uses helper functions from disent.util.visualize.vis_img to convert tensors to images, disent.util.visualize.vis_latents to generate latent sequences, and disent.util.visualize.vis_util to combine the images together into a grid or make sequential frames.

The code is more complicated than it needs to be for most cases because of some additional handling and quirks. Maybe we can add a specific docs example for latent traversals.
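For reference, the core idea behind a latent traversal can be sketched in plain PyTorch, independent of disent's helpers. This is only a minimal illustration: `TinyDecoder` is a hypothetical stand-in for the decoder of your trained VAE, and `latent_traversal` is not a disent function.

```python
import torch

# Hypothetical stand-in for a trained VAE decoder: maps a latent
# vector of size z_dim to a 64x64 RGB image.
class TinyDecoder(torch.nn.Module):
    def __init__(self, z_dim=6):
        super().__init__()
        self.net = torch.nn.Linear(z_dim, 64 * 64 * 3)

    def forward(self, z):
        return self.net(z).view(-1, 3, 64, 64)

def latent_traversal(decoder, z, dim, num_frames=8, extent=2.0):
    """Decode copies of z where only latent `dim` is swept over a range."""
    zs = z.repeat(num_frames, 1)               # (num_frames, z_dim)
    zs[:, dim] = torch.linspace(-extent, extent, num_frames)
    with torch.no_grad():
        return decoder(zs)                     # (num_frames, 3, 64, 64)

decoder = TinyDecoder(z_dim=6)
z = torch.zeros(6)  # in practice: the encoder's mean for some input image
# traverse every latent dimension -> (z_dim, num_frames, 3, 64, 64)
stills = torch.stack([latent_traversal(decoder, z, d) for d in range(6)])
print(stills.shape)  # torch.Size([6, 8, 3, 64, 64])
```

The callback linked above does essentially this, plus extra logic for choosing traversal ranges and assembling the results into grids/animations.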

nmichlo avatar Jan 05 '23 12:01 nmichlo

Thanks! The callback method returns stills, frames, and image. How should I pass these into plot_dataset_traversals or visualize_dataset_traversal, and what are the corresponding values there? E.g. does stills correspond to the grid input of plt_subplots_imshow?

E.g. does this example make sense? Many thanks!

trainer = pl.Trainer(
    max_steps=2048,
    gpus=1 if torch.cuda.is_available() else None,
    logger=False,
    checkpoint_callback=False,
    max_epochs=1
)
trainer.fit(module, dataloader)
# trainer.save_checkpoint("trained.ckpt")

viz = VaeLatentCycleLoggingCallback()
stills, frames_, image_ = viz.generate_visualisations(trainer_or_dataset=trainer, pl_module=trainer.lightning_module,
                                                      num_frames=4, num_stats_samples=15)

plt_scale = 4.5
offset = 0.75
factors, frames, _, _, c = stills.shape

plt_subplots_imshow(grid=stills, title=None, row_labels=None, subplot_padding=None,
                    figsize=(offset + (1 / 2.54) * frames * plt_scale, (1 / 2.54) * (factors + 0.45) * plt_scale),
                    show=False)


dgm2 avatar Jan 08 '23 14:01 dgm2

Your example makes sense, but admittedly it has been a while since I last touched the code. (I realize the current system is not ideal for these custom scripts, so this will need to be fixed in future.)

  • stills should be an array of shape (num_latents, num_frames, 64, 64, 3) containing the individual latent traversals.
  • frames is a concatenated version of stills intended for creating videos: the individual stills are combined over the factors dimension into an image grid per frame. The final array has approximate shape (num_frames, ~(64 * grid_h), ~(64 * grid_w), 3).
  • image is a single plottable image with all the latent traversals merged into one grid; the x axis of this grid corresponds to num_latents and the y axis to num_frames (or vice versa), so the shape is approximately (~(64 * num_latents), ~(64 * num_frames), 3).

You can try plotting the image directly with plt.imshow(image), or create your own visualization/animation from the frames or stills.
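As a rough illustration of those shapes (using random data as a stand-in for the callback's actual stills output), merging stills into a single grid image and showing it could look something like this:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for the example
import matplotlib.pyplot as plt

num_latents, num_frames = 6, 8
# stand-in for the callback's `stills` output: (num_latents, num_frames, 64, 64, 3)
stills = np.random.rand(num_latents, num_frames, 64, 64, 3)

# merge into one big image: rows = latents, columns = frames
grid = stills.transpose(0, 2, 1, 3, 4).reshape(num_latents * 64, num_frames * 64, 3)
print(grid.shape)  # (384, 512, 3)

plt.imshow(grid)
plt.axis("off")
plt.savefig("traversal_grid.png", bbox_inches="tight")
```

The transpose interleaves the image rows of each still so that the reshape produces a proper 2D grid rather than scrambled pixels.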

nmichlo avatar Jan 11 '23 07:01 nmichlo