[Feat]: Per-batch latent distribution sampling

Open cheald opened this issue 11 months ago • 2 comments

Describe your use-case.

The expensive part of VAE encoding is computing the latent distribution; actually sampling from that distribution is cheap and easy. When latent caching is enabled, only a single point of the latent distribution (currently its mean) is cached. Training robustness improves if we instead draw a fresh sample from the latent distribution each time that image is used. Empirically, this has noticeably improved the robustness of my training run without any significant computational overhead: sampling the distribution is extremely cheap, but it gives us much more varied input data, so the network learns the distribution of your latents rather than specific samples of those latents.
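A minimal sketch of why per-use sampling is cheap: once the distribution parameters are cached, a fresh latent is a single reparameterization step. This assumes the packed (B, 2C, H, W) layout used by diffusers' `DiagonalGaussianDistribution` (mean and log-variance stacked on the channel dimension) and reimplements its sampling in plain torch; `sample_latent` is a hypothetical name, not a OneTrainer or MGDS function.

```python
from typing import Optional

import torch


def sample_latent(params: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Draw z = mean + std * eps from packed VAE distribution parameters.

    Assumes `params` is (B, 2C, H, W): mean in the first C channels,
    log-variance in the last C (the layout diffusers' DiagonalGaussianDistribution uses).
    """
    mean, logvar = torch.chunk(params, 2, dim=1)
    logvar = torch.clamp(logvar, -30.0, 20.0)  # same clamp range diffusers applies
    std = torch.exp(0.5 * logvar)
    # The only per-use cost: one Gaussian draw plus an elementwise multiply-add.
    eps = torch.randn(mean.shape, generator=generator, device=params.device, dtype=params.dtype)
    return mean + std * eps


# Stand-in for cached SD1.5-style params: 4 mean + 4 log-variance channels.
params = torch.randn(2, 8, 64, 64)
z = sample_latent(params, torch.Generator().manual_seed(0))
print(z.shape)  # torch.Size([2, 4, 64, 64])
```

Returning `mean` alone corresponds to the current `mode='mean'` behaviour; only the `eps` draw changes from step to step.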

I've attempted to implement this in the dataloader itself (by caching latent_image_distribution rather than latent_image), but that breaks something in MGDS that I have not yet been able to track down (IndexError: list index out of range in CollectPaths; changing what's cached seems to break a counter somewhere). However, my hack below works as a proof of concept.

This is a feature request rather than a PR because the right solution is to do this in the dataloader, but that needs MGDS fixes.

What would you like to see as a solution?

I've implemented this locally by patching the data loader to expose the latent distribution:

(This is done by exposing the params because DataLoader wants tensors, not objects wrapping tensors)

diff --git a/modules/dataLoader/StableDiffusionBaseDataLoader.py b/modules/dataLoader/StableDiffusionBaseDataLoader.py
index 2767ee2..01bafe8 100644
--- a/modules/dataLoader/StableDiffusionBaseDataLoader.py
+++ b/modules/dataLoader/StableDiffusionBaseDataLoader.py
@@ -14,6 +14,7 @@ from mgds.pipelineModules.DecodeVAE import DecodeVAE
 from mgds.pipelineModules.DiskCache import DiskCache
 from mgds.pipelineModules.EncodeClipText import EncodeClipText
 from mgds.pipelineModules.EncodeVAE import EncodeVAE
+from mgds.pipelineModules.MapData import MapData
 from mgds.pipelineModules.RescaleImageChannels import RescaleImageChannels
 from mgds.pipelineModules.SampleVAEDistribution import SampleVAEDistribution
 from mgds.pipelineModules.SaveImage import SaveImage
@@ -65,6 +66,7 @@ class StableDiffusionBaseDataLoader(
         rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1)
         rescale_conditioning_image = RescaleImageChannels(image_in_name='conditioning_image', image_out_name='conditioning_image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1)
         encode_image = EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype())
+        dist_params = MapData(in_name='latent_image_distribution', out_name='latent_image_distribution_params', map_fn=lambda x: x.parameters.float().squeeze(0))
         image_sample = SampleVAEDistribution(in_name='latent_image_distribution', out_name='latent_image', mode='mean')
         downscale_mask = ScaleImage(in_name='mask', out_name='latent_mask', factor=0.125)
         encode_conditioning_image = EncodeVAE(in_name='conditioning_image', out_name='latent_conditioning_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype())
@@ -73,7 +75,7 @@ class StableDiffusionBaseDataLoader(
         tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=model.tokenizer.model_max_length)
         encode_prompt = EncodeClipText(in_name='tokens', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True, text_encoder=model.text_encoder, hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype())

-        modules = [rescale_image, encode_image, image_sample, tokenize_prompt]
+        modules = [rescale_image, encode_image, dist_params, image_sample, tokenize_prompt]

         if config.masked_training or config.model_type.has_mask_input():
             modules.append(downscale_mask)
@@ -92,7 +94,7 @@ class StableDiffusionBaseDataLoader(
         return modules

     def _cache_modules(self, config: TrainConfig, model: StableDiffusionModel):
-        image_split_names = ['latent_image']
+        image_split_names = ['latent_image', 'latent_image_distribution_params']

         if config.masked_training or config.model_type.has_mask_input():
             image_split_names.append('latent_mask')
@@ -150,7 +152,7 @@ class StableDiffusionBaseDataLoader(
         return modules

     def _output_modules(self, config: TrainConfig, model: StableDiffusionModel):
-        output_names = ['latent_image', 'tokens', 'image_path', 'prompt']
+        output_names = ['latent_image', 'latent_image_distribution_params', 'tokens', 'image_path', 'prompt']

         if config.masked_training or config.model_type.has_mask_input():
             output_names.append('latent_mask')

Then, in GenericTrainer, I replace latent_image with a fresh sample:

diff --git a/modules/trainer/GenericTrainer.py b/modules/trainer/GenericTrainer.py
index 5d2dd64..7741938 100644
--- a/modules/trainer/GenericTrainer.py
+++ b/modules/trainer/GenericTrainer.py
@@ -661,8 +669,13 @@ class GenericTrainer(BaseTrainer):
                         self.model_setup.setup_train_device(self.model, self.config)

                 self.callbacks.on_update_status("training")

                 with TorchMemoryRecorder(enabled=False):
+                    if "latent_image_distribution_params" in batch:
+                        batch["latent_image"] = generate_latents_from_dist(batch["latent_image_distribution_params"], train_progress.global_step)
+
                     model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
@@ -774,3 +794,13 @@ class GenericTrainer(BaseTrainer):

         for handle in self.grad_hook_handles:
             handle.remove()
+
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+def generate_latents_from_dist(dists, seed):
+    generator = torch.Generator(device=dists.device)
+    generator.manual_seed(seed)
+    distribution = DiagonalGaussianDistribution(dists)
+    return distribution.sample(generator=generator)

This is a hack and just a proof of concept, though.
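On the note that DataLoader wants tensors rather than objects wrapping tensors: assuming the loader ends up using PyTorch's default collate function, per-item tensors are stacked into a batch, while arbitrary objects are rejected with a `TypeError`. A small sketch of that behaviour, where `LatentDist` is a hypothetical stand-in for a distribution object:

```python
import torch
from torch.utils.data import default_collate


class LatentDist:
    """Hypothetical wrapper object holding distribution parameters."""

    def __init__(self, parameters: torch.Tensor):
        self.parameters = parameters


# Per-item params after .squeeze(0): (2C, H, W) = (8, 64, 64) for an SD1.5-style VAE.
items = [torch.randn(8, 64, 64) for _ in range(3)]
batch = default_collate(items)
print(batch.shape)  # torch.Size([3, 8, 64, 64])

try:
    default_collate([LatentDist(t) for t in items])
except TypeError as err:
    # Wrapper objects are not a collatable type, hence exposing the raw params tensor.
    print("cannot collate wrapper objects:", type(err).__name__)
```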

Have you considered alternatives? List them here.

Ideally, the dataloader itself would NOT cache latent_image. Instead, it would cache the latent distribution and draw a new sample from it each time the item is selected for inclusion in a batch.
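A minimal sketch of that alternative, written as a plain PyTorch `Dataset` rather than an MGDS pipeline module (MGDS internals aren't shown here, so this makes no claim to match them). `ResampledLatentDataset` and its seeding scheme are hypothetical: it keeps the cached (2C, H, W) distribution parameters and draws a new latent on every `__getitem__`.

```python
from typing import List, Optional

import torch
from torch.utils.data import Dataset


class ResampledLatentDataset(Dataset):
    """Hypothetical sketch: cache distribution params, return a fresh latent on every access.

    Each entry of `cached_params` is a CPU tensor of shape (2C, H, W) with mean and
    log-variance stacked on dim 0. This is not an MGDS or OneTrainer API.
    """

    def __init__(self, cached_params: List[torch.Tensor], base_seed: Optional[int] = None):
        self.cached_params = cached_params
        self.base_seed = base_seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # The training loop calls this so each epoch gets different draws.
        self.epoch = epoch

    def __len__(self) -> int:
        return len(self.cached_params)

    def __getitem__(self, index: int) -> torch.Tensor:
        mean, logvar = torch.chunk(self.cached_params[index], 2, dim=0)
        std = torch.exp(0.5 * torch.clamp(logvar, -30.0, 20.0))
        generator = None  # unseeded: fall back to torch's global RNG
        if self.base_seed is not None:
            # Simple deterministic mix of (seed, epoch, index); any stable combination works.
            seed = ((self.base_seed * 1_000_003 + self.epoch) * 1_000_003 + index) % (2**63)
            generator = torch.Generator().manual_seed(seed)
        eps = torch.randn(mean.shape, generator=generator, dtype=mean.dtype)
        return mean + std * eps
```

If the same image is fetched more than once per epoch, seeding from the image index alone would reuse one draw for every repeat within that epoch; seeding from a per-fetch counter (or the sample slot) avoids that.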

cheald · Jan 23 '25

> Training robustness is improved by instead sampling a new sample from the latent distribution each time we want that image

Do you have evidence for that? I've only implemented mean sampling because everything I've seen and read makes it pretty clear this is the best approach. IIRC this was even mentioned in some papers I read.

> Actually sampling that distribution is very cheap and easy.

This is true. But to sample the distribution during training, you need to cache that distribution, which doubles the size of the latent cache.
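To put the size increase in numbers, a rough sketch assuming an SD1.5-style VAE (4 latent channels, 8x spatial downscale), a 512x512 image, and fp32 storage; `latent_cache_bytes` is a hypothetical helper:

```python
def latent_cache_bytes(height: int, width: int, channels: int = 4,
                       downscale: int = 8, bytes_per_elem: int = 4) -> int:
    """Bytes to cache one latent tensor for an image of the given size (defaults: SD1.5-style VAE, fp32)."""
    return channels * (height // downscale) * (width // downscale) * bytes_per_elem


sample_only = latent_cache_bytes(512, 512)              # one cached sample: 4 channels
dist_params = latent_cache_bytes(512, 512, channels=8)  # mean + log-variance: 8 channels
print(sample_only, dist_params, dist_params / sample_only)  # 65536 131072 2.0
```

The proof-of-concept diff caches both `latent_image` and `latent_image_distribution_params`, so as written it would roughly triple the per-image latent storage (at equal precision) rather than double it.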

I don't think there's even a theoretical reason why sampling from that distribution makes sense during training. The only reason there is a distribution at all is to fix a problem in the latent space: similar values in pixel space should be mapped to similar values in latent space. But there is only ever one single "perfect" mapping between pixel and latent space (and even that isn't perfect; it's still lossy).
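For reference on where that distribution comes from: the KL-regularized autoencoder used for latent diffusion penalizes each latent element's Gaussian $\mathcal{N}(\mu, \sigma^2)$ against a standard normal. The standard closed form of that penalty is:

```math
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = \tfrac{1}{2}\left(\mu^2 + \sigma^2 - 1 - \log \sigma^2\right)
```

The $-\log \sigma^2$ term grows without bound as $\sigma^2 \to 0$, which is what keeps the encoder from collapsing into a deterministic point mapping.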

Nerogar · Jan 24 '25

My best evidence is subjective observations from my own training, but it also makes a certain kind of sense mathematically: the math begins from the presumption that $q(x_0)$ is a probability distribution describing the input data, and that $z \sim q(z)$ rather than $z \in \{z_n, z_{n+1}, \dots, z_m\}$. I've actually been chasing a separate problem (specifically, the decay in $\hat\epsilon$ variance away from 1 as $t \rightarrow 0$), but stumbled across this while trying to align all the fundamental expectations with the LDM and DDIM papers, and I was pleased enough with the results that I thought it was worth mentioning.

In practice, it seems to help with learning small details; over longer runs, my results have felt truer to my training data when I resample the latent per step. This is effectively what would happen without latent caching, but this approach gets us that sample diversity without re-encoding the source image for every batch. It also matches the explicit process of the LDM paper, which samples a latent as $z = \varepsilon_\mu(x) + \varepsilon_\sigma(x)\epsilon$, where $\epsilon \sim \mathcal{N}(0, I)$. It is NOT specified as $z = \varepsilon_\mu(x)$, even though that is obviously a decent approximation in practice.

Another way to think about the problem: given a true image $i$, we want to train the network to learn the distribution of latents $p$ that decode to something approximating $i$. VAEs are lossy compressors, so I don't think it's really correct to say that there's "one true VAE sampling" for any given image; there is a distribution of encodings that approximate the input, and if the mean were necessarily the best, the learned variance should be zero anyway.

This is going to be a very tight distribution, and the mean will certainly get "close enough" for the most part, but there are some interesting problems, such as the "brain tumors" in the distribution. The log variance of the distribution is not uniform across the latent, which in practice means "to reproduce the true image $i$, it is more important to learn the exact mean of certain pixels in the latent than others" (notably, this problem may not exist outside the SD1.5 VAEs; I haven't done any work on this with non-1.5 models). Training on multiple samples of the latent distribution should, in theory, help the network learn those distinctions.

I haven't evaluated how the latent distribution variances vary across my dataset, but $z \approx \varepsilon_\mu(x)$ might also have unintended effects if your dataset is composed of latents with a wider spread of variances; the observations in the "brain tumor" post suggest this could actually be likely in practice, even though it shouldn't matter much in theory.
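One way to check this on a concrete dataset, sketched under the assumption that each cached entry is a (2C, H, W) tensor with mean and log-variance stacked on dim 0 (as the squeezed `latent_image_distribution_params` in the diff above presumably is). `latent_std_stats` is a hypothetical helper:

```python
from typing import Dict, List

import torch


def latent_std_stats(cached_params: List[torch.Tensor]) -> Dict[str, float]:
    """Summarize how the per-image average latent std is spread across a dataset."""
    per_image = []
    for params in cached_params:
        _, logvar = torch.chunk(params, 2, dim=0)
        std = torch.exp(0.5 * torch.clamp(logvar, -30.0, 20.0))
        per_image.append(std.mean())  # one scalar per image
    stds = torch.stack(per_image)
    return {
        "min": stds.min().item(),
        "max": stds.max().item(),
        "mean": stds.mean().item(),
        # torch.std of a single value is NaN (unbiased estimator), so guard that case.
        "spread": stds.std().item() if len(stds) > 1 else 0.0,
    }
```

A large gap between `min` and `max` would indicate that some images have much wider distributions than others, which is the case where always using the mean could matter.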

If you can remember any of the papers which have argued for just taking the mean of the latent distribution, though, I'd love to read them!

Here are a couple of quick-and-dirty comparisons. My local branch has been modified to replace all CPU-based random functions with torch equivalents, using a single generator seeded with global_step for all random functions in each step, and torch.use_deterministic_algorithms(True) is set, specifically to help evaluate this sort of change.
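The branch itself isn't shown, so this is only a guess at the general shape of such a setup, using documented PyTorch reproducibility switches; `step_generator` and `random_offset` are hypothetical helpers:

```python
import os

import torch

# cuBLAS needs this for deterministic CUDA matmuls; set it before any CUDA work.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
# Ops without a deterministic implementation will now raise instead of silently varying.
torch.use_deterministic_algorithms(True)


def step_generator(global_step: int, device: str = "cpu") -> torch.Generator:
    """Hypothetical helper: one generator per training step, seeded with the step number."""
    return torch.Generator(device=device).manual_seed(global_step)


def random_offset(max_offset: int, gen: torch.Generator) -> int:
    """Hypothetical replacement for random.randint(0, max_offset) that draws from `gen`."""
    return int(torch.randint(0, max_offset + 1, (1,), generator=gen).item())


gen = step_generator(35)
print(random_offset(64, gen))
```

Routing every random decision through one step-seeded generator means two runs that reach the same `global_step` make identical choices, which is what makes the hash comparison below meaningful.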

I ran 10 epochs of 100 samples each, with 15 actual source image files (so each image should be repeated 6-7 times per epoch on average). The optimizer was ADOPT, with lr 5e-4 for both the UNet and TE and a weight decay of 5e-7. This is a LoRA trained at rank 96/alpha 96 on the Realistic Vision 5.1 checkpoint. For these runs, no masking or noise offset was used, though I do generally use them in my training runs. Over 10 epochs I should see 1000 total samples, so each image should be seen about 66.67 times (or, when using latent sampling, get 66-67 distinct samplings).
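The sample counts in this setup work out as follows, assuming images are drawn as evenly as possible (a shuffled sampler will deviate slightly):

```python
samples_per_epoch = 100
source_images = 15
epochs = 10

total_samples = samples_per_epoch * epochs                # 1000
per_image_per_epoch = samples_per_epoch / source_images  # ~6.67
per_image_total = total_samples / source_images          # ~66.67

# 1000 = 15 * 66 + 10: with an even spread, 10 images are seen 67 times and 5 are seen 66 times.
base, extra = divmod(total_samples, source_images)
print(total_samples, round(per_image_per_epoch, 2), round(per_image_total, 2), base, extra)
# 1000 6.67 66.67 66 10
```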

I cranked the LR way up for the purposes of demonstration, since I wanted the accumulated difference to show up more quickly in these tests. In practice I use lower LRs and more epochs, which is where this change really has an effect, since it increases the number of effectively unique samples the network sees. The accelerated LR is sufficient for this demo, but combined with the lack of any regularization other than weight decay, it does produce some unaesthetic samples :)

First, I did two back-to-back runs up to 5 epochs to demonstrate that the training process is in fact perfectly deterministic, and that any observed variance isn't due to differences in randomization or the use of non-deterministic algorithms:

❯ sha256sum *.jpg

ec271665f5957f7948fbe4c629ca2930c509124dfb156b3cd269421bc1e72eb0  2025-01-24_10-35-43-31317e-425124311-35-5-0.jpg
ec271665f5957f7948fbe4c629ca2930c509124dfb156b3cd269421bc1e72eb0  2025-01-24_10-42-21-2fcf70-425124311-35-5-0.jpg

(visual comparison confirms they are identical, too, though the hash should be sufficient evidence of that)
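When comparing more than a couple of runs, the same byte-for-byte check can be scripted; a small sketch using only Python's standard library (the helper names are illustrative):

```python
import hashlib
from pathlib import Path


def sha256_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 of a file, read in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def identical_outputs(a: Path, b: Path) -> bool:
    """True if two sample images are byte-for-byte identical (the same check as `sha256sum`)."""
    return sha256_file(a) == sha256_file(b)
```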

Here are a couple of my ground-truth images:

[Image: ground-truth training images]

And here's the output of the runs: one sample per epoch, with the left-hand column showing epochs 1-10 without per-batch latent sampling and the right-hand column with per-batch sampling. Key observations:

  1. This change does alter the training dynamics, and the differences accumulate over time.
  2. To my eye, the right-hand images capture minor details that improve the fidelity of the subject. The distortions are due to the overall aggressive learning rate and lack of regularization, but those are separate problems to solve.

The left-hand images are arguably a little more aesthetic, but I think this is a misleading observation: the right-hand images (with latent sampling) pick up more of the subtleties that give my source images their unique characteristics. The better aesthetics on the left reflect, I think, the trained model's somewhat diminished capacity to move away from the base model. In practice, when I add other regularization techniques, I feel like I'm getting substantially better results over 30-40 epochs than I do with just a single latent sample.

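[Image: per-epoch samples for epochs 1-10; left column without per-batch latent sampling, right column with it]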

cheald · Jan 24 '25