Guo He

1 comment by Guo He

I have initialized all modules in `configure_sharded_model`:

```python
def configure_sharded_model(self):
    # Lightning hook for creating modules in a distributed-aware (sharding) context
    if self.use_colossalai:
        rank_zero_info("Configure sharded model for LatentDiffusion")
        self.model = DiffusionWrapper(self.unet_config, self.conditioning_key)
        count_params(self.model, verbose=True)
        if self.use_ema:
            # exponential-moving-average copy of the model weights
            self.model_ema = LitEma(self.model)
        if self.ckpt is...
```
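The pattern above defers building the model (and its EMA copy) from `__init__` to the `configure_sharded_model` hook. The PyTorch Lightning docs advise keeping this hook idempotent, since it can be called once per stage (fit/val/test/predict) in the same process. The framework-free sketch below illustrates that idea under those assumptions. `StandInModel`, `StandInEma`, and `DeferredInitModule` are hypothetical stand-ins for `DiffusionWrapper`, `LitEma`, and the LightningModule; they are not their real implementations.

```python
from dataclasses import dataclass, field


@dataclass
class StandInModel:
    """Hypothetical stand-in for DiffusionWrapper: just holds named parameters."""
    params: dict = field(default_factory=lambda: {"w": 0.0, "b": 0.0})


class StandInEma:
    """Hypothetical stand-in for LitEma: exponential moving average of the params."""

    def __init__(self, model: StandInModel, decay: float = 0.9):
        self.decay = decay
        self.shadow = dict(model.params)  # start from a copy of the current weights

    def update(self, model: StandInModel) -> None:
        # shadow <- decay * shadow + (1 - decay) * current value
        for name, value in model.params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * value


class DeferredInitModule:
    """Hypothetical module that builds its submodules in a hook, not in __init__."""

    def __init__(self, use_ema: bool = True):
        # Only store configuration here; no heavy modules are built yet.
        self.use_ema = use_ema
        self.model = None
        self.model_ema = None
        self.build_count = 0

    def configure_sharded_model(self) -> None:
        # Guard so repeated calls (e.g. one per stage) do not rebuild the model.
        if self.model is not None:
            return
        self.model = StandInModel()
        self.build_count += 1
        if self.use_ema:
            self.model_ema = StandInEma(self.model)


module = DeferredInitModule()
module.configure_sharded_model()
module.configure_sharded_model()  # second call is a no-op
```

Guarding on `self.model is not None` keeps a second call from replacing already-built (and possibly already-sharded) modules.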