Guo He
I have initialized all modules in configure_sharded_model:

    def configure_sharded_model(self):
        if self.use_colossalai:
            rank_zero_info("Configure sharded model for LatentDiffusion")
            self.model = DiffusionWrapper(self.unet_config, self.conditioning_key)
            count_params(self.model, verbose=True)
            if self.use_ema:
                self.model_ema = LitEma(self.model)
            if self.ckpt is...
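The pattern above defers building the model until the `configure_sharded_model` hook runs instead of building it in `__init__`. Below is a minimal, dependency-free sketch of that deferred-initialization idea, under stated assumptions. `TinyModel` and `LazyDiffusionModule` are hypothetical stand-ins for `DiffusionWrapper` and the LatentDiffusion module, and `copy.deepcopy` stands in for `LitEma`. The eager path for the non-sharded case is also an assumption, because the original snippet only shows the sharded branch. In PyTorch Lightning 1.x the trainer calls this hook itself inside the strategy's sharding context. The direct call in the demo exists only for illustration.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class TinyModel:
    """Hypothetical stand-in for DiffusionWrapper: holds a flat list of weights."""
    weights: list = field(default_factory=lambda: [0.0] * 4)

    def num_params(self) -> int:
        return len(self.weights)


class LazyDiffusionModule:
    """Sketch: defer submodule construction to configure_sharded_model."""

    def __init__(self, use_sharding: bool, use_ema: bool, num_weights: int = 4):
        # Store configuration only; with sharding enabled, nothing is built yet.
        self.use_sharding = use_sharding
        self.use_ema = use_ema
        self.num_weights = num_weights
        self.model = None
        self.model_ema = None
        if not use_sharding:
            # Assumption: without a sharded strategy, build eagerly as usual.
            self._build()

    def _build(self) -> None:
        self.model = TinyModel([0.0] * self.num_weights)
        if self.use_ema:
            # Stand-in for LitEma: an independent copy of the model's weights.
            self.model_ema = copy.deepcopy(self.model)

    def configure_sharded_model(self) -> None:
        # Build only on the sharded path, and only once even if the hook
        # is invoked again.
        if self.use_sharding and self.model is None:
            self._build()


if __name__ == "__main__":
    module = LazyDiffusionModule(use_sharding=True, use_ema=True)
    print(module.model)  # None: nothing allocated at construction time
    module.configure_sharded_model()  # normally invoked by the trainer
    print(module.model.num_params())  # 4
```

The idempotence guard matters because the hook can be reached more than once, and rebuilding would silently discard any weights already loaded from a checkpoint.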