How can I train on my own data with an MMGeneration model alone?

Open mikaizhu opened this issue 2 years ago • 3 comments

Hi! How can I use an MMGeneration model on its own? For example, I want to use a conditional GAN such as SAGAN or SNGAN-Proj.

I want to use the MMGeneration model like this:

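dataloader = MyDataloader()
generator = SAGAN_generator().to(device)
discriminator = SAGAN_discriminator().to(device)

for feature, label in dataloader:
    feature = feature.to(device)
    label = label.to(device)
    # sample one noise vector per example (128 matches the noise size used later in this thread)
    z = torch.randn(feature.size(0), 128, device=device)
    fake_data = generator(z, label)
    ...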

I tried to separate the generator and discriminator out of your model code, but I failed. What should I do?

Thanks!

mikaizhu avatar May 17 '22 15:05 mikaizhu

Can you show how you instantiate SAGAN_generator and SAGAN_discriminator and the console output? We prefer to build models by build_model(MODEL_CONFIG). Configs for SAGAN can be found in this URL.

LeoXing1996 avatar May 18 '22 02:05 LeoXing1996

Thanks for your reply! @LeoXing1996

from mmgen.models import build_model
from mmcv import Config

cfg = Config.fromfile('/userhome/mmgeneration/configs/sagan/sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py')
model = build_model(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

Then I print the model and get the following output:

BasicConditionalGAN(
  (generator): SNGANGenerator(
    (noise2feat): Linear(in_features=128, out_features=4096, bias=True)
    (conv_blocks): ModuleList(
      (0): SNGANGenResBlock(
        (activate): ReLU(inplace=True)
        (upsample): Upsample(scale_factor=2.0, mode=nearest)
        (conv_1): SNConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (conv_2): SNConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (norm_1): SNConditionNorm(
          (norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (weight_embedding): Embedding(10, 256)
          (bias_embedding): Embedding(10, 256)
        )
        (norm_2): SNConditionNorm(
          (norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (weight_embedding): Embedding(10, 256)
          (bias_embedding): Embedding(10, 256)
        )
        (shortcut): SNConvModule(
          (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (1): SNGANGenResBlock(
        (activate): ReLU(inplace=True)
        (upsample): Upsample(scale_factor=2.0, mode=nearest)
        (conv_1): SNConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (conv_2): SNConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (norm_1): SNConditionNorm(
          (norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (weight_embedding): Embedding(10, 256)
          (bias_embedding): Embedding(10, 256)
        )
        (norm_2): SNConditionNorm(
          (norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (weight_embedding): Embedding(10, 256)
          (bias_embedding): Embedding(10, 256)
        )
        (shortcut): SNConvModule(
          (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (2): SelfAttentionBlock(
        (theta): SNConvModule(
          (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (phi): SNConvModule(
          (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (g): SNConvModule(
          (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (o): SNConvModule(
          (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
      (3): SNGANGenResBlock(
        (activate): ReLU(inplace=True)
        (upsample): Upsample(scale_factor=2.0, mode=nearest)
        (conv_1): SNConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (conv_2): SNConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (norm_1): SNConditionNorm(
          (norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (weight_embedding): Embedding(10, 256)
          (bias_embedding): Embedding(10, 256)
        )
        (norm_2): SNConditionNorm(
          (norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
          (weight_embedding): Embedding(10, 256)
          (bias_embedding): Embedding(10, 256)
        )
        (shortcut): SNConvModule(
          (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
    (to_rgb): ConvModule(
      (conv): Conv2d(256, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (final_act): Tanh()
  )
  (discriminator): ProjDiscriminator(
    (from_rgb): SNGANDiscHeadResBlock(
      (activate): ReLU(inplace=True)
      (conv_1): SNConvModule(
        (conv): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (conv_2): SNConvModule(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (downsample): AvgPool2d(kernel_size=2, stride=2, padding=0)
      (shortcut): SNConvModule(
        (conv): Conv2d(3, 128, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (conv_blocks): ModuleList(
      (0): SelfAttentionBlock(
        (theta): SNConvModule(
          (conv): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (phi): SNConvModule(
          (conv): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (g): SNConvModule(
          (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (o): SNConvModule(
          (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
      (1): SNGANDiscResBlock(
        (activate): ReLU(inplace=True)
        (conv_1): SNConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (conv_2): SNConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (downsample): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (shortcut): SNConvModule(
          (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (2): SNGANDiscResBlock(
        (activate): ReLU(inplace=True)
        (conv_1): SNConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (conv_2): SNConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (3): SNGANDiscResBlock(
        (activate): ReLU(inplace=True)
        (conv_1): SNConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (conv_2): SNConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (decision): Linear(in_features=128, out_features=1, bias=True)
    (proj_y): Embedding(10, 128)
    (activate): ReLU(inplace=True)
  )
  (gan_loss): GANLoss(
    (loss): ReLU()
  )
)

I want to get the discriminator and generator out of the model and then apply them to my own data, for example:

z = torch.randn(10, 128, device=device)
fake_labels = torch.arange(0, 10, dtype=torch.long, device=device)

image = torch.randn(10, 2, 32, 32, device=device)
label = torch.arange(0, 10, dtype=torch.long, device=device)

model_D(fake_image, fake_labels)

What should I do?

mikaizhu avatar May 18 '22 04:05 mikaizhu

You can use the following code:

model_G = model.generator
model_D = model.discriminator

LeoXing1996 avatar May 20 '22 02:05 LeoXing1996
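
For readers who want to see how the two extracted sub-networks might then be driven by a hand-written loop, here is a minimal sketch in plain PyTorch. It does not import mmgen: `TinyGenerator`, `TinyProjDiscriminator`, `Wrapper`, and `train_step` are hypothetical stand-ins introduced only for illustration, and the hinge loss is an assumption suggested by the ReLU inside the printed GANLoss. The sizes follow the printout above (noise size 128, 10 classes, 3 input channels in the discriminator's first convolution), so a 2-channel image batch would presumably not fit that discriminator without changing the config. With the real modules, check the forward signatures of the generator and discriminator before passing labels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyGenerator(nn.Module):
    """Hypothetical stand-in for a conditional generator (not mmgen code)."""

    def __init__(self, noise_size=128, num_classes=10, channels=3):
        super().__init__()
        self.channels = channels
        self.embed = nn.Embedding(num_classes, noise_size)
        self.net = nn.Sequential(
            nn.Linear(noise_size, channels * 32 * 32), nn.Tanh())

    def forward(self, noise, label):
        # Condition on the class by modulating the noise with a label embedding.
        x = self.net(noise * self.embed(label))
        return x.view(-1, self.channels, 32, 32)


class TinyProjDiscriminator(nn.Module):
    """Hypothetical stand-in for a projection discriminator (not mmgen code)."""

    def __init__(self, num_classes=10, channels=3, feat=64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Flatten(), nn.Linear(channels * 32 * 32, feat), nn.ReLU())
        self.decision = nn.Linear(feat, 1)
        self.proj_y = nn.Embedding(num_classes, feat)

    def forward(self, image, label):
        h = self.features(image)
        # Projection term: inner product of features and the label embedding.
        proj = (h * self.proj_y(label)).sum(dim=1, keepdim=True)
        return self.decision(h) + proj


class Wrapper(nn.Module):
    """Plays the role of the built model: it only holds the two sub-networks."""

    def __init__(self):
        super().__init__()
        self.generator = TinyGenerator()
        self.discriminator = TinyProjDiscriminator()


def train_step(model_G, model_D, opt_G, opt_D, real, label, noise_size=128):
    """One hinge-loss update of the discriminator, then one of the generator."""
    z = torch.randn(real.size(0), noise_size, device=real.device)
    fake = model_G(z, label)

    # Discriminator: push real scores above +1 and fake scores below -1.
    loss_D = (F.relu(1.0 - model_D(real, label)).mean()
              + F.relu(1.0 + model_D(fake.detach(), label)).mean())
    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()

    # Generator: raise the discriminator's score on generated images.
    loss_G = -model_D(fake, label).mean()
    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()
    return loss_D.item(), loss_G.item()


torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Wrapper().to(device)
# Same attribute access as suggested in the reply above.
model_G, model_D = model.generator, model.discriminator

opt_G = torch.optim.Adam(model_G.parameters(), lr=2e-4)
opt_D = torch.optim.Adam(model_D.parameters(), lr=2e-4)

# Random tensors stand in for one batch from a user-defined dataloader.
real = torch.randn(10, 3, 32, 32, device=device)
label = torch.arange(0, 10, dtype=torch.long, device=device)
loss_D, loss_G = train_step(model_G, model_D, opt_G, opt_D, real, label)
```

In a real loop, `train_step` would be called once per batch from your own dataloader. The config name in this thread contains `ndisc5`, which suggests several discriminator updates per generator update; the sketch uses one of each for brevity.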