DeepSpeed

Update vae.py

Open · mzamini92 opened this issue 1 year ago · 2 comments

Since the DSVAE class already inherits from torch.nn.Module, it does not also need to inherit from CUDAGraph, so the CUDAGraph base class can be removed. Instead of calling self.vae.requires_grad_(requires_grad=False), the torch.no_grad() context manager can be used during initialization to disable gradient computation for the self.vae module. Finally, the _graph_replay_decoder, _graph_replay_encoder, and _graph_replay methods could be decorated with @torch.no_grad() so that graph replays do not record autograd history.
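A minimal sketch of the proposed changes, with CUDA graph capture and replay omitted entirely. TinyVAE, WrappedVAE, decode and frozen_after_no_grad_init are hypothetical names, not DeepSpeed's DSVAE API. One behavior worth checking before applying the second suggestion: torch.no_grad() only affects operations executed inside it and does not change requires_grad on parameters, so constructing a module under it does not freeze the module. The sketch therefore keeps requires_grad_(False) and uses @torch.no_grad() on the forward-style method.

```python
import torch
import torch.nn as nn


class TinyVAE(nn.Module):
    """Hypothetical stand-in for the wrapped VAE (not DeepSpeed's DSVAE)."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 2)
        self.decoder = nn.Linear(2, 4)


class WrappedVAE(nn.Module):
    """Hypothetical wrapper that inherits only from nn.Module, as the PR proposes."""

    def __init__(self, vae: nn.Module):
        super().__init__()
        self.vae = vae
        # requires_grad_(False) persistently marks every parameter as frozen.
        self.vae.requires_grad_(requires_grad=False)

    @torch.no_grad()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        # Operations in this method record no autograd history.
        return self.vae.decoder(z)


def frozen_after_no_grad_init() -> bool:
    # Building a module under no_grad does NOT freeze it: parameters are
    # created with requires_grad=True regardless of the current grad mode.
    with torch.no_grad():
        vae = TinyVAE()
    return all(not p.requires_grad for p in vae.parameters())


wrapped = WrappedVAE(TinyVAE())
out = wrapped.decode(torch.zeros(1, 2))
print(out.requires_grad, frozen_after_no_grad_init())
```

The two mechanisms are complementary: requires_grad_(False) changes the parameters themselves, while @torch.no_grad() only controls whether the calls inside a method build a graph.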

mzamini92 · Jun 16 '23 15:06

@microsoft-github-policy-service agree

mzamini92 · Jun 16 '23 15:06