const_layout
LayoutNet - how to train?
Could you share a script for training LayoutNet for the FID calculation? I can see the model architecture file; could you share the training code, or at least the loss functions used?
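LayoutNet is almost the same as the Discriminator. The only difference is where the fake samples come from: instead of being produced by the Generator, they are made by adding Gaussian noise to the bounding boxes of real layouts. The training code looks like this: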
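```python
import random

import torch
import torch.nn.functional as F

# fake layout generation: with probability 0.5, perturb the real bboxes
# with Gaussian noise (std 0.05) and label the whole batch as fake
if random.random() < 0.5:
    data.x = data.x + 0.05 * torch.randn_like(data.x)  # noise on the same device as data.x
    is_real = torch.zeros(batch_size, device=data.x.device)
else:
    is_real = torch.ones(batch_size, device=data.x.device)
...  # forward pass elided: produces logit_disc, logit_cls, bbox_pred
# loss computation
loss_bce = F.binary_cross_entropy_with_logits(logit_disc, is_real)  # real vs. fake
loss_recl = F.cross_entropy(logit_cls, data.y)  # label reconstruction
loss_recb = F.mse_loss(bbox_pred, data.x)  # bbox reconstruction
loss = loss_bce + loss_recl + 10 * loss_recb
```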
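For a self-contained starting point, here is a minimal sketch of one full training step built around the loss above. It is not the repository's code: `ToyLayoutNet`, `train_step`, the MLP encoder/decoder with mean pooling, and the use of dense `(B, N, 4)` bbox and `(B, N)` label tensors (every layout assumed to have the same number of elements N, no padding masks) are hypothetical simplifications standing in for the real LayoutNet and the `data.x` / `data.y` batch. Only the noise-based fake generation and the loss terms and weights come from the answer.

```python
import random

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyLayoutNet(nn.Module):
    """Hypothetical stand-in for LayoutNet (not the repo's architecture).

    Encodes a layout of N elements (bbox + label) into one pooled vector z,
    predicts a real/fake logit from z, and reconstructs every element's
    label and bbox from z plus a learned per-position query.
    """

    def __init__(self, num_label: int, max_elem: int, dim: int = 64):
        super().__init__()
        self.emb_label = nn.Embedding(num_label, dim)
        self.fc_bbox = nn.Linear(4, dim)
        self.encoder = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.pos = nn.Parameter(0.02 * torch.randn(max_elem, dim))
        self.decoder = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
        self.head_disc = nn.Linear(dim, 1)         # real/fake logit per layout
        self.head_cls = nn.Linear(dim, num_label)  # label logits per element
        self.head_bbox = nn.Linear(dim, 4)         # bbox per element

    def forward(self, bbox, label):
        h = self.encoder(self.fc_bbox(bbox) + self.emb_label(label))  # (B, N, D)
        z = h.mean(dim=1)                                              # (B, D)
        logit_disc = self.head_disc(z).squeeze(-1)                     # (B,)
        dec = self.decoder(z.unsqueeze(1) + self.pos[: bbox.size(1)])  # (B, N, D)
        return logit_disc, self.head_cls(dec), self.head_bbox(dec)


def train_step(model, optimizer, bbox, label):
    """One hypothetical LayoutNet update using the loss from the answer."""
    batch_size = bbox.size(0)
    # fake layouts: perturb the bboxes with Gaussian noise half the time
    if random.random() < 0.5:
        bbox = bbox + 0.05 * torch.randn_like(bbox)
        is_real = torch.zeros(batch_size, device=bbox.device)
    else:
        is_real = torch.ones(batch_size, device=bbox.device)

    logit_disc, logit_cls, bbox_pred = model(bbox, label)
    loss_bce = F.binary_cross_entropy_with_logits(logit_disc, is_real)
    # flatten (B, N, C) -> (B*N, C) so cross_entropy sees one row per element
    loss_recl = F.cross_entropy(logit_cls.flatten(0, 1), label.flatten())
    # target is the (possibly perturbed) input, as in the snippet above
    loss_recb = F.mse_loss(bbox_pred, bbox)
    loss = loss_bce + loss_recl + 10 * loss_recb

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    random.seed(0)
    model = ToyLayoutNet(num_label=5, max_elem=8)
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    bbox = torch.rand(16, 8, 4)            # dummy layouts, 8 elements each
    label = torch.randint(0, 5, (16, 8))   # dummy element labels
    for step in range(3):
        print(step, train_step(model, opt, bbox, label))
```

To use it for real, swap in the actual LayoutNet and data loader; the loss lines carry over unchanged apart from the `flatten` calls, which are only needed because labels are batched as `(B, N)` in this sketch.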