GANLatentDiscovery
Fix incorrect dimensions when loading the shift predictor
When loading the shift predictor, its output dimension was not set correctly.
It did not match the number of directions of the latent deformator. For example, with a 64-direction deformator, the shift predictor was initialized by default with 128 dimensions. This did not fail during training because PyTorch's cross-entropy loss takes integer class indices as targets, and every index below 64 is also a valid class for a 128-way output.
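To illustrate why the mismatch goes unnoticed, here is a minimal pure-Python sketch of softmax cross-entropy with an integer class target. The cross_entropy function is a hand-written stand-in for the integer-target case of PyTorch's loss, not its actual implementation; the sizes 64 and 128 are taken from the example above.

```python
import math

def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample with an integer class target.

    Stand-in for the integer-target case: only logits[target] is selected,
    so any target in [0, len(logits)) is accepted without complaint.
    """
    if not 0 <= target < len(logits):
        raise IndexError(f"target {target} out of range for {len(logits)} classes")
    # log-sum-exp with the max subtracted for numerical stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

directions_count = 64  # number of deformator directions
predictor_dim = 128    # mismatched default size of the predictor head

logits = [0.0] * predictor_dim
# Targets are drawn from the deformator's 64 directions, so they are always
# valid indices into 128 logits: no error is raised, the loss is simply
# computed over 64 extra classes that can never be the target.
loss = cross_entropy(logits, directions_count - 1)
print(round(loss, 4))  # 4.852, i.e. log(128)
```

With a correctly sized 64-way head, the same uniform logits would give log(64) instead; the extra logits only dilute the softmax, which is why nothing crashes.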
This small change fixes the mismatch by explicitly setting the shift predictor's dimension to directions_count.
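A minimal sketch of the fix, plus a check that would have caught the mismatch at load time. ShiftPredictor, DEFAULT_PREDICTOR_DIM and check_matches_deformator are hypothetical stand-ins, not the repository's actual classes or loading code.

```python
from dataclasses import dataclass

# Hypothetical default head size, matching the 128 in the example above.
DEFAULT_PREDICTOR_DIM = 128

@dataclass
class ShiftPredictor:
    """Hypothetical stand-in that only records the size of the direction head."""
    dim: int = DEFAULT_PREDICTOR_DIM

def check_matches_deformator(predictor: ShiftPredictor, directions_count: int) -> None:
    """Raise if the predictor's head does not cover exactly the deformator's directions."""
    if predictor.dim != directions_count:
        raise ValueError(
            f"shift predictor has {predictor.dim} outputs, "
            f"but the deformator has {directions_count} directions"
        )

directions_count = 64

old = ShiftPredictor()                      # before: falls back to the default size
new = ShiftPredictor(dim=directions_count)  # after: size set explicitly

check_matches_deformator(new, directions_count)  # passes silently
try:
    check_matches_deformator(old, directions_count)
except ValueError as err:
    print(err)  # shift predictor has 128 outputs, but the deformator has 64 directions
```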
Please correct me if my interpretation is wrong.