# [DRAFT]: Refactor of diffusion samplers
## Description
Note: this PR is currently a draft and is open for feedback and suggestions. It will not be merged in its current form and might be broken down into smaller PRs.
## Objectives
Overarching objective: refactor and improve all utilities and models related to diffusion in order to consolidate these components into a framework. More precisely, these diffusion components should:
- Support all current diffusion use cases (CorrDiff, cBottle, StormCast, etc.) out of the box
- Be composable
- Be extendable
- Be well-documented
PR objective: this PR focuses on the EDM samplers (`stochastic_sampler` and `deterministic_sampler`). For these samplers, the objectives are to:
- Be agnostic to the diffusion model
- Be agnostic to the modality of the latent state
- Be agnostic to the modality of the conditioning (for conditional diffusion models)
- Support a wide range of guidance methods for plug-and-play generation (e.g., DPS)
- Support multiple implementations of multi-diffusion
## Solutions
- Refactored the `stochastic_sampler` functional interface into an object-oriented interface to facilitate future extensibility (see the first sketch after this list).
- Model agnostic: relies on a callback whose signature is assumed invariant. To satisfy this invariance constraint, we also provide a surface adapter (i.e., a thin wrapper) that adapts the signature of any given `Module` to ensure compliance (second sketch below).
- Latent state agnostic: simple refactors to avoid unnecessary assumptions on the shape of the latent state.
- Conditioning agnostic: all conditioning variables are packed into a dict of tensors. The sampler never accesses the conditioning; the model is responsible for handling the conditioning ops.
- Plug-and-play guidance: relies on callbacks passed to the sampler. Introduces a guidance API to facilitate creating these callbacks and to ensure compliance with the sampler requirements. For now, two types of guidance are provided: model-based guidance for DPS, and data-consistency guidance for inpainting, outpainting, channel infilling, etc. (third sketch below).
- Multi-diffusion: TBD. For now the plan is to defer multi-diffusion ops to a dedicated model wrapper.
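
To make the target interface concrete, here is a minimal sketch of what the object-oriented sampler could look like. All names, defaults, and the `(x, sigma, condition)` callback signature are illustrative assumptions, not the actual PhysicsNeMo API, and the EDM stochastic churn is omitted for brevity:

```python
import torch

class StochasticSampler:
    """Minimal object-oriented EDM sampler loop. Names and defaults are
    illustrative; stochastic churn is omitted, so this reduces to a
    first-order Euler sampler over the Karras noise schedule."""

    def __init__(self, num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        self.num_steps = num_steps
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def sigma_schedule(self, device):
        # Karras et al. (EDM) noise-level schedule, ending at sigma = 0.
        i = torch.arange(self.num_steps, device=device)
        s = (self.sigma_max ** (1 / self.rho)
             + i / (self.num_steps - 1)
             * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))) ** self.rho
        return torch.cat([s, torch.zeros(1, device=device)])

    @torch.no_grad()
    def __call__(self, denoiser, latents, condition=None):
        # Only assumption on `latents`: it is a tensor; its shape is opaque.
        sigmas = self.sigma_schedule(latents.device)
        x = latents * sigmas[0]
        for sigma_cur, sigma_next in zip(sigmas[:-1], sigmas[1:]):
            # Invariant callback signature; `condition` is passed through opaquely.
            denoised = denoiser(x, sigma_cur, condition)
            x = x + (sigma_next - sigma_cur) * (x - denoised) / sigma_cur
        return x
```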
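
The surface adapter could then be a thin `Module` wrapper that maps this invariant callback signature onto a model's native forward signature. The kwarg name `noise_labels` below is a hypothetical example of a model-specific argument:

```python
import torch

class SignatureAdapter(torch.nn.Module):
    """Thin wrapper exposing the invariant `(x, sigma, condition)` signature
    expected by the sampler, whatever the wrapped model's own API is.
    Illustrative sketch; the adapter shipped in this PR may differ."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x, sigma, condition=None):
        condition = condition or {}
        # Unpack the conditioning dict into the wrapped model's native
        # keyword arguments; the sampler itself never inspects `condition`.
        return self.model(x, noise_labels=sigma, **condition)
```

Usage would then look like `sampler = StochasticSampler(); x = sampler(SignatureAdapter(net), latents, condition={"y": labels})`, which also illustrates the conditioning-as-dict convention: the sampler forwards the dict opaquely and only the model unpacks it.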
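
For the guidance API, an illustrative data-consistency callback for inpainting could look as follows; the `(x, denoised, sigma)` callback signature and the helper name are assumptions made for this sketch, not the PR's actual guidance API:

```python
import torch

def make_inpainting_guidance(observed: torch.Tensor, mask: torch.Tensor):
    """Build a data-consistency guidance callback for inpainting: wherever
    `mask` is 1, the denoised estimate is overwritten with the observed data.
    Hypothetical helper, shown only to illustrate the callback pattern."""

    def guidance(x, denoised, sigma):
        return torch.where(mask.bool(), observed, denoised)

    return guidance

# The sampler would invoke such callbacks after each denoising step, e.g.:
# denoised = guidance(x, denoised, sigma_cur)
```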
## Remaining items
- Multi-diffusion: in their current implementation, the samplers use a patching object to extract patches from the latent state `x`; they also call methods on the model to extract the global positional embeddings of these patches. These strong assumptions about the model API are incompatible with the model-agnostic objective above. A better solution might be to defer all multi-diffusion ops to a model wrapper that extracts patches, gets the positional embeddings, etc. (see the sketch after this list).
- Guidance: only pre-defined guidances are supported; there is no mechanism for user-defined guidance yet. The range of available off-the-shelf guidances should also be extended.
- Model-based guidance: only supports models that process batches, are implemented in PyTorch, and are compatible with `torch.autograd`.
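
As a possible direction for the multi-diffusion item, a model wrapper could own all patching ops so the sampler keeps seeing the same invariant `(x, sigma, condition)` callback as in the single-patch case. The sketch below is hypothetical: it assumes an image-like latent and simplifies boundary handling and positional embeddings:

```python
import torch

class MultiDiffusionWrapper(torch.nn.Module):
    """Hypothetical wrapper deferring multi-diffusion ops (patch extraction,
    per-patch conditioning/embeddings, patch fusion) to the model side,
    keeping the sampler patch-agnostic."""

    def __init__(self, model: torch.nn.Module, patch_size: int, overlap: int = 0):
        super().__init__()
        self.model = model
        self.patch_size = patch_size
        self.stride = patch_size - overlap

    def forward(self, x, sigma, condition=None):
        # Assumes an (N, C, H, W) latent for this sketch; boundary patches are
        # clamped instead of padded, and per-patch positional embeddings would
        # be computed here in a real implementation.
        _, _, H, W = x.shape
        out, count = torch.zeros_like(x), torch.zeros_like(x)
        for i in range(0, H, self.stride):
            for j in range(0, W, self.stride):
                sl = (..., slice(i, i + self.patch_size), slice(j, j + self.patch_size))
                out[sl] += self.model(x[sl], sigma, condition)
                count[sl] += 1
        return out / count  # average predictions over overlapping patches
```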
## Checklist
- [ ] I am familiar with the Contributing Guidelines.
- [ ] New or existing tests cover these changes.
- [ ] The documentation is up to date with these changes.
- [ ] The CHANGELOG.md is up to date with these changes.
- [ ] An issue is linked to this pull request.