DeepSpeed
Update vae.py
Since the `DSVAE` class already inherits from `torch.nn.Module`, there is no need for it to inherit from `CUDAGraph` as well, so the `CUDAGraph` inheritance can be removed. Instead of calling `self.vae.requires_grad_(requires_grad=False)`, you can use the `torch.no_grad()` context manager during initialization to disable gradient computation for the `self.vae` module. The `_graph_replay_decoder`, `_graph_replay_encoder`, and `_graph_replay` methods can also benefit from the `@torch.no_grad()` decorator.
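As a rough illustration of the decorator suggestion, here is a minimal sketch using a hypothetical `TinyWrapper` class and an `nn.Linear` stand-in for the VAE. Neither is the actual `DSVAE` code, and the method names `decode` and `decode_with_grad` are made up for the example. It also shows one thing worth keeping in mind: `torch.no_grad()` only suppresses autograd recording inside the decorated call or `with` block, whereas `requires_grad_(False)` freezes the parameters themselves, so the two are not interchangeable in every situation.

```python
import torch
import torch.nn as nn


class TinyWrapper(nn.Module):
    """Hypothetical stand-in for a wrapper like DSVAE; not the DeepSpeed class."""

    def __init__(self, vae: nn.Module):
        super().__init__()
        self.vae = vae

    @torch.no_grad()
    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        # No autograd graph is recorded for anything computed in this method.
        return self.vae(latents)

    def decode_with_grad(self, latents: torch.Tensor) -> torch.Tensor:
        # Same computation, but autograd is left enabled.
        return self.vae(latents)


wrapper = TinyWrapper(nn.Linear(4, 4))  # nn.Linear is only a placeholder "VAE"
x = torch.randn(2, 4)

# The decorator affects only the decorated call.
no_grad_tracks = wrapper.decode(x).requires_grad
with_grad_tracks = wrapper.decode_with_grad(x).requires_grad

# requires_grad_(False) freezes the parameters for every later call.
wrapper.vae.requires_grad_(False)
frozen_tracks = wrapper.decode_with_grad(x).requires_grad
params_frozen = all(not p.requires_grad for p in wrapper.parameters())
```

In this sketch the undecorated method still produces a tensor that tracks gradients until the parameters are frozen explicitly, so whether the `requires_grad_` call can be dropped depends on whether any code path runs the VAE outside a `no_grad` scope.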
@microsoft-github-policy-service agree