mmgeneration
How can I train on my own data with an MMGeneration model alone?
Hi, how can I use an MMGeneration model by itself? For example, I want to use a conditional GAN such as SAGAN or SNGAN-Proj.
I want to use the model like this:
dataloader = MyDataloader()
generator = SAGAN_generator().to(device)
discriminator = SAGAN_discriminator().to(device)
for feature, label in dataloader:
    feature = feature.to(device)
    label = label.to(device)
    z = torch.randn(label.size(0), noise_dim, device=device)  # sample a noise batch
    fake_data = generator(z, label)
    ...
I've tried to separate out your model but failed. What should I do?
Thanks!
Can you show how you instantiate SAGAN_generator and SAGAN_discriminator, and the console output? We prefer to build models with build_model(MODEL_CONFIG). Configs for SAGAN can be found in this URL.
Thanks for your reply! @LeoXing1996
from mmgen.models import build_model
from mmcv import Config
cfg = Config.fromfile('/userhome/mmgeneration/configs/sagan/sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py')
model = build_model(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
Then I get the following output for model:
BasicConditionalGAN(
(generator): SNGANGenerator(
(noise2feat): Linear(in_features=128, out_features=4096, bias=True)
(conv_blocks): ModuleList(
(0): SNGANGenResBlock(
(activate): ReLU(inplace=True)
(upsample): Upsample(scale_factor=2.0, mode=nearest)
(conv_1): SNConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(conv_2): SNConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(norm_1): SNConditionNorm(
(norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
(weight_embedding): Embedding(10, 256)
(bias_embedding): Embedding(10, 256)
)
(norm_2): SNConditionNorm(
(norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
(weight_embedding): Embedding(10, 256)
(bias_embedding): Embedding(10, 256)
)
(shortcut): SNConvModule(
(conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
)
)
(1): SNGANGenResBlock(
(activate): ReLU(inplace=True)
(upsample): Upsample(scale_factor=2.0, mode=nearest)
(conv_1): SNConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(conv_2): SNConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(norm_1): SNConditionNorm(
(norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
(weight_embedding): Embedding(10, 256)
(bias_embedding): Embedding(10, 256)
)
(norm_2): SNConditionNorm(
(norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
(weight_embedding): Embedding(10, 256)
(bias_embedding): Embedding(10, 256)
)
(shortcut): SNConvModule(
(conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): SelfAttentionBlock(
(theta): SNConvModule(
(conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(phi): SNConvModule(
(conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(g): SNConvModule(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(o): SNConvModule(
(conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(3): SNGANGenResBlock(
(activate): ReLU(inplace=True)
(upsample): Upsample(scale_factor=2.0, mode=nearest)
(conv_1): SNConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(conv_2): SNConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(norm_1): SNConditionNorm(
(norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
(weight_embedding): Embedding(10, 256)
(bias_embedding): Embedding(10, 256)
)
(norm_2): SNConditionNorm(
(norm): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
(weight_embedding): Embedding(10, 256)
(bias_embedding): Embedding(10, 256)
)
(shortcut): SNConvModule(
(conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
)
)
)
(to_rgb): ConvModule(
(conv): Conv2d(256, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(final_act): Tanh()
)
(discriminator): ProjDiscriminator(
(from_rgb): SNGANDiscHeadResBlock(
(activate): ReLU(inplace=True)
(conv_1): SNConvModule(
(conv): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(conv_2): SNConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(downsample): AvgPool2d(kernel_size=2, stride=2, padding=0)
(shortcut): SNConvModule(
(conv): Conv2d(3, 128, kernel_size=(1, 1), stride=(1, 1))
)
)
(conv_blocks): ModuleList(
(0): SelfAttentionBlock(
(theta): SNConvModule(
(conv): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(phi): SNConvModule(
(conv): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(g): SNConvModule(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(o): SNConvModule(
(conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): SNGANDiscResBlock(
(activate): ReLU(inplace=True)
(conv_1): SNConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(conv_2): SNConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(downsample): AvgPool2d(kernel_size=2, stride=2, padding=0)
(shortcut): SNConvModule(
(conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): SNGANDiscResBlock(
(activate): ReLU(inplace=True)
(conv_1): SNConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(conv_2): SNConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(3): SNGANDiscResBlock(
(activate): ReLU(inplace=True)
(conv_1): SNConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(conv_2): SNConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(decision): Linear(in_features=128, out_features=1, bias=True)
(proj_y): Embedding(10, 128)
(activate): ReLU(inplace=True)
)
(gan_loss): GANLoss(
(loss): ReLU()
)
)
I want to get the discriminator and the generator from the model and then use them on my own data, for example:
z = torch.randn(10, 128, device=device)
fake_labels = torch.arange(0, 10, dtype=torch.long, device=device)
image = torch.randn(10, 3, 32, 32, device=device)  # 3-channel 32x32 images, like CIFAR-10
label = torch.arange(0, 10, dtype=torch.long, device=device)
fake_image = model_G(z, fake_labels)
model_D(fake_image, fake_labels)
What should I do?
You can use the following code:
model_G = model.generator
model_D = model.discriminator
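For completeness, here is a minimal sketch of calling the extracted sub-modules on a batch. It assumes the generator's forward takes a noise batch plus a label tensor and the discriminator's forward takes an image batch plus a label tensor, which matches the printed structure above; please check the mmgen source for the exact forward signatures before relying on this.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model_G = model.generator
model_D = model.discriminator

# sample a noise batch and per-sample class labels
z = torch.randn(10, 128, device=device)
fake_labels = torch.arange(0, 10, dtype=torch.long, device=device)

# generate fake images conditioned on the labels, then score them with the discriminator
fake_image = model_G(z, label=fake_labels)
fake_score = model_D(fake_image, label=fake_labels)

# score a batch of (stand-in) real images with their labels
real_image = torch.randn(10, 3, 32, 32, device=device)
real_label = torch.arange(0, 10, dtype=torch.long, device=device)
real_score = model_D(real_image, label=real_label)
From here you can plug model_G and model_D into your own training loop, computing the generator and discriminator losses on these scores and stepping your own optimizers.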