
[DRAFT]: Refactor of diffusion samplers

CharlelieLrt opened this pull request 3 months ago • 0 comments

PhysicsNeMo Pull Request

Description

Note: this PR is currently a draft and is open for feedback and suggestions. It will not be merged in its current form and might be broken down into smaller PRs.

Objectives

Overarching objective: refactor and improve all diffusion-related utilities and models in order to consolidate these components into a coherent framework. More precisely, these diffusion components should:

  • Support all current diffusion use cases (CorrDiff, cBottle, StormCast, etc.) out of the box
  • Be composable
  • Be extendable
  • Be well-documented

PR objective: this PR focuses on the EDM samplers (stochastic_sampler and deterministic_sampler). The refactored samplers should:

  1. Be agnostic to the diffusion model
  2. Be agnostic to the modality of the latent state x
  3. Be agnostic to the modality of the conditioning (for conditional diffusion models)
  4. Support a wide range of guidance methods for plug-and-play generation (e.g. DPS)
  5. Support multiple implementations of multi-diffusion

Solutions

  1. Refactored the stochastic_sampler functional interface into an object-oriented interface to facilitate future extensibility.

  2. Model agnostic: the sampler relies on a model callback whose signature is assumed invariant. To satisfy this invariance constraint, we also provide a surface adapter (i.e. a thin wrapper) that adapts the signature of any given Module to ensure compliance (see the adapter sketch after this list).

  3. Latent state agnostic: simple refactors to avoid unnecessary assumptions on the shape of the latent state.

  4. Conditioning agnostic: all conditioning variables are packed into a dict of tensors. The sampler never accesses the conditioning, since the model is responsible for handling all conditioning ops.

  5. Plug-and-play guidance: relies on callbacks passed to the sampler. Introduces a guidance API to facilitate creating these callbacks and to ensure compliance with the sampler's requirements. For now, two types of guidance are provided: model-based guidance (for DPS) and data-consistency guidance (for inpainting, outpainting, channel infilling, etc.); see the guidance sketch after this list.

  6. Multi-diffusion: TBD. For now the plan is to defer multi-diffusion ops to a dedicated model wrapper.
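
To make item 2 concrete, here is a minimal sketch of what such a surface adapter could look like. Everything here is an assumption for illustration: the class name `SamplerModelAdapter`, the invariant signature `(x, sigma, condition)`, and the `arg_map` mechanism are not the actual PhysicsNeMo API.

```python
import torch
import torch.nn as nn


class SamplerModelAdapter(nn.Module):
    """Hypothetical thin wrapper that maps an arbitrary model signature
    onto the invariant signature assumed by the sampler:
    denoiser(x, sigma, condition) -> denoised x."""

    def __init__(self, model: nn.Module, arg_map: dict[str, str]):
        super().__init__()
        self.model = model
        # arg_map translates the sampler-side names ("x", "sigma") and the
        # keys of the conditioning dict into the wrapped model's kwargs.
        self.arg_map = arg_map

    def forward(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        condition: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        kwargs = {self.arg_map["x"]: x, self.arg_map["sigma"]: sigma}
        # The sampler never inspects the conditioning (item 4): it is
        # forwarded as an opaque dict and consumed by the model itself.
        for key, tensor in condition.items():
            kwargs[self.arg_map.get(key, key)] = tensor
        return self.model(**kwargs)
```

A model whose forward is `forward(self, noisy, noise_level, lowres)` would then be adapted with `SamplerModelAdapter(net, {"x": "noisy", "sigma": "noise_level"})`, and the sampler only ever sees the invariant signature.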
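Similarly, here is a hedged sketch of the guidance-callback idea from item 5, using data-consistency guidance for inpainting as an example. The callback signature `(x, denoised, sigma)` and the class name are illustrative assumptions, and the re-noising details of a full data-consistency scheme are omitted.

```python
import torch


class DataConsistencyGuidance:
    """Hypothetical guidance callback: projects the denoised estimate onto
    observed values inside a mask (simplified inpainting guidance)."""

    def __init__(self, observation: torch.Tensor, mask: torch.Tensor):
        self.observation = observation  # known data (e.g. unmasked pixels)
        self.mask = mask                # 1 where observed, 0 where unknown

    def __call__(
        self, x: torch.Tensor, denoised: torch.Tensor, sigma: torch.Tensor
    ) -> torch.Tensor:
        # sigma is unused in this simplified projection; it is kept in the
        # signature so noise-level-dependent guidances fit the same API.
        return self.mask * self.observation + (1.0 - self.mask) * denoised
```

The sampler would invoke each such callback after computing the denoised estimate at every step, e.g. via a `guidance=[...]` argument on the sampler object; that argument name is likewise an assumption.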

Remaining items

- Multi-diffusion: in their current implementation, the samplers use a patching object to extract patches from the latent state x; they also call methods on the model to extract the global positional embeddings of these patches. These strong assumptions on the model API are not compatible with objective (1) above. A better solution might be to defer all multi-diffusion ops to a model wrapper that extracts patches, gets the positional embeddings, etc. (see the sketch after this list).

- Guidance: the samplers currently only support pre-defined guidances; there is no mechanism for user-defined guidance. The range of available off-the-shelf guidances should also be extended.

- Model-based guidance: currently only supports models that process batches, are implemented in PyTorch, and are compatible with torch.autograd.
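
As a sketch of the model-wrapper idea mentioned in the first remaining item: a wrapper that hides patch extraction and fusion from the sampler while exposing the same invariant signature. The `patching` object and its `split`/`merge` methods are hypothetical placeholders, not the planned implementation.

```python
import torch
import torch.nn as nn


class MultiDiffusionWrapper(nn.Module):
    """Hypothetical wrapper deferring all multi-diffusion ops to the model
    side, so the sampler remains patch-unaware."""

    def __init__(self, model: nn.Module, patching):
        super().__init__()
        self.model = model
        self.patching = patching  # assumed to split/merge patches

    def forward(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        condition: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        # Extract local patches (and, in a real implementation, their
        # global positional embeddings) from the global latent state.
        patches = self.patching.split(x)
        denoised = [self.model(p, sigma, condition) for p in patches]
        # Fuse overlapping patch predictions (e.g. by weighted averaging)
        # back into a single global denoised estimate.
        return self.patching.merge(denoised, like=x)
```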

Checklist

  • [ ] I am familiar with the Contributing Guidelines.
  • [ ] New or existing tests cover these changes.
  • [ ] The documentation is up to date with these changes.
  • [ ] The CHANGELOG.md is up to date with these changes.
  • [ ] An issue is linked to this pull request.

Dependencies

CharlelieLrt • Sep 05 '25