const_layout

LayoutNet - how to train?

Open • sukritiverma1996 opened this issue 2 years ago • 1 comment

Could you share a script to train LayoutNet for FID calculation? I see the model architecture file; could you share the training code, or at least the loss functions used?
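For context on the FID part of the question: once LayoutNet is trained, FID compares the distribution of its layout features for real and generated layouts by fitting a Gaussian to each feature set and taking the Fréchet distance between them. The sketch below is a minimal NumPy version of that formula, assuming features have already been extracted as `(num_samples, dim)` arrays; `frechet_distance` is a hypothetical helper name, not a function from const_layout.

```python
import numpy as np


def _sqrtm_psd(a):
    # square root of a symmetric PSD matrix via eigendecomposition
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0, None))) @ v.T


def frechet_distance(feats_a, feats_b):
    """||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^{1/2}) for (n, dim) arrays."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((S_a S_b)^{1/2}) equals Tr((S_a^{1/2} S_b S_a^{1/2})^{1/2}),
    # and the inner matrix is symmetric PSD, so eigvalsh applies
    sqrt_a = _sqrtm_psd(cov_a)
    m = sqrt_a @ cov_b @ sqrt_a
    m = (m + m.T) / 2  # remove tiny numerical asymmetry
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(m), 0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2 * tr_covmean)
```

Identical feature sets give a distance of 0; shifting every feature by a constant only changes the mean term.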

sukritiverma1996 • Jun 28 '22 15:06

LayoutNet is almost the same as the Discriminator, except for where the fake samples come from. The code looks roughly like this:

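import random

import torch
import torch.nn.functional as F

# fake layout generation: with probability 0.5, perturb the real boxes
# with Gaussian noise (std 0.05) and label the whole batch as fake
if random.random() < 0.5:
    # randn_like keeps the noise on the same device/dtype as data.x
    data.x = data.x + 0.05 * torch.randn_like(data.x)
    is_real = torch.zeros(batch_size, device=data.x.device)
else:
    is_real = torch.ones(batch_size, device=data.x.device)

...

# loss computation
loss_bce = F.binary_cross_entropy_with_logits(logit_disc, is_real)  # real vs. fake
loss_recl = F.cross_entropy(logit_cls, data.y)  # label reconstruction
loss_recb = F.mse_loss(bbox_pred, data.x)  # bbox reconstruction
loss = loss_bce + loss_recl + 10 * loss_recb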
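The snippet elides the model and the training loop, so here is one way the pieces could fit together as a self-contained training step. Several assumptions are not from the thread: layouts are padded to dense tensors (`bbox` of shape `[B, N, 4]` normalized to [0, 1], `label` of shape `[B, N]`, and a boolean `padding_mask`) instead of the `data.x`/`data.y` batch used above, and `ToyLayoutNet`, `features`, and `train_step` are hypothetical names for a small Transformer stand-in, not the const_layout LayoutNet or Discriminator architecture. Only the fake-sample rule and the loss weights follow the snippet.

```python
import random

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyLayoutNet(nn.Module):
    """Hypothetical LayoutNet-like model (not the const_layout code).

    Encodes padded (label, bbox) sequences into one feature per layout,
    predicts a real/fake logit from it, and reconstructs every element's
    label and box from that single feature.
    """

    def __init__(self, num_label, d_model=64, max_len=25):
        super().__init__()
        self.emb_label = nn.Embedding(num_label, d_model)
        self.fc_bbox = nn.Linear(4, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=4, dim_feedforward=128,
                                       batch_first=True), num_layers=2)
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=4, dim_feedforward=128,
                                       batch_first=True), num_layers=2)
        self.pos_emb = nn.Parameter(0.02 * torch.randn(1, max_len, d_model))
        self.fc_disc = nn.Linear(d_model, 1)
        self.fc_cls = nn.Linear(d_model, num_label)
        self.fc_out_bbox = nn.Linear(d_model, 4)

    def features(self, bbox, label, padding_mask):
        # padding_mask is True at padded element slots
        h = self.emb_label(label) + self.fc_bbox(bbox)
        h = self.encoder(h, src_key_padding_mask=padding_mask)
        keep = (~padding_mask).unsqueeze(-1).float()
        return (h * keep).sum(1) / keep.sum(1).clamp(min=1.0)  # masked mean

    def forward(self, bbox, label, padding_mask):
        feat = self.features(bbox, label, padding_mask)
        logit_disc = self.fc_disc(feat).squeeze(-1)
        # decode all element slots from the pooled feature plus a position embedding
        h = feat.unsqueeze(1) + self.pos_emb[:, : bbox.size(1)]
        h = self.decoder(h, src_key_padding_mask=padding_mask)
        bbox_pred = torch.sigmoid(self.fc_out_bbox(h))  # boxes assumed in [0, 1]
        return logit_disc, self.fc_cls(h), bbox_pred


def train_step(model, optimizer, bbox, label, padding_mask, noise_std=0.05):
    batch_size = bbox.size(0)
    # fake layout generation, as in the snippet: with probability 0.5 the
    # whole batch is noised and labeled fake (and reconstructs the noised boxes)
    if random.random() < 0.5:
        bbox = bbox + noise_std * torch.randn_like(bbox)
        is_real = torch.zeros(batch_size, device=bbox.device)
    else:
        is_real = torch.ones(batch_size, device=bbox.device)

    logit_disc, logit_cls, bbox_pred = model(bbox, label, padding_mask)
    valid = ~padding_mask  # reconstruction losses skip padded slots
    loss_bce = F.binary_cross_entropy_with_logits(logit_disc, is_real)
    loss_recl = F.cross_entropy(logit_cls[valid], label[valid])
    loss_recb = F.mse_loss(bbox_pred[valid], bbox[valid])
    loss = loss_bce + loss_recl + 10 * loss_recb

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# usage on random data: 8 layouts, up to 10 elements, 5 label classes
torch.manual_seed(0)
random.seed(0)
model = ToyLayoutNet(num_label=5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
bbox = torch.rand(8, 10, 4)
label = torch.randint(0, 5, (8, 10))
padding_mask = torch.zeros(8, 10, dtype=torch.bool)
padding_mask[:, 7:] = True  # last 3 slots are padding
for _ in range(3):
    print(train_step(model, optimizer, bbox, label, padding_mask))
```

Reconstructing labels and boxes from the single pooled feature, rather than from per-element encoder outputs, keeps reconstruction from being a trivial copy, so the pooled vector has to summarize the whole layout; that vector (`features` here) is the kind of embedding one would then feed to an FID computation.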

ktrk115 • Jun 29 '22 12:06