DiscoGAN-pytorch
DiscoGAN-pytorch copied to clipboard
CUDA runtime error (8): invalid device function
class Generator(nn.Module):
    """Fully connected generator: a stack of Linear+ReLU layers ending in Linear.

    Args:
        input_size: number of features in each input sample.
        output_size: number of features in each output sample.
        hidden_dims: width of each hidden layer; one Linear+ReLU pair is
            created per entry (an empty sequence gives a single Linear).
    """

    def __init__(self, input_size, output_size, hidden_dims):
        super(Generator, self).__init__()
        self.layers = []
        prev_dim = input_size
        for hidden_dim in hidden_dims:
            self.layers.append(nn.Linear(prev_dim, hidden_dim))
            self.layers.append(nn.ReLU(inplace=True))
            prev_dim = hidden_dim
        self.layers.append(nn.Linear(prev_dim, output_size))
        # A plain Python list is invisible to nn.Module, so the same layer
        # objects are also registered through nn.ModuleList. That is what makes
        # them appear in parameters()/state_dict() (keys "layer_module.<idx>.*")
        # and follow .cuda()/.cpu().
        self.layer_module = nn.ModuleList(self.layers)

    def forward(self, x):
        """Apply every layer in order and return the final activation."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
class Discriminator(nn.Module):
    """Fully connected discriminator: Linear+ReLU stack, final Linear, Sigmoid.

    Args:
        input_size: number of features in each input sample.
        output_size: number of units in the final Linear layer.
        hidden_dims: width of each hidden layer; one Linear+ReLU pair is
            created per entry.

    The forward output is passed through a Sigmoid, so every value lies in
    (0, 1), and it is reshaped to a single column with ``view(-1, 1)``.
    """

    def __init__(self, input_size, output_size, hidden_dims):
        super(Discriminator, self).__init__()
        self.layers = []
        prev_dim = input_size
        for hidden_dim in hidden_dims:
            self.layers.append(nn.Linear(prev_dim, hidden_dim))
            self.layers.append(nn.ReLU(inplace=True))
            prev_dim = hidden_dim
        self.layers.append(nn.Linear(prev_dim, output_size))
        self.layers.append(nn.Sigmoid())
        # Register the same layer objects with nn.Module (a plain list is not
        # tracked), so parameters()/state_dict()/.cuda() see them under
        # "layer_module.<idx>.*".
        self.layer_module = nn.ModuleList(self.layers)

    def forward(self, x):
        """Run all layers and return the scores as a single column."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out.view(-1, 1)
network
# Network sizes for the 2-D toy setup.
hidden_dim = 128
g_num_layer = 3
d_num_layer = 5

# Generators map 2-D points to 2-D points (A -> B and B -> A); each
# discriminator maps a 2-D point to one score. Module.cuda() moves the
# parameters to the current CUDA device and returns the module itself.
G_AB = Generator(2, 2, g_num_layer * [hidden_dim]).cuda()
G_BA = Generator(2, 2, g_num_layer * [hidden_dim]).cuda()
D_A = Discriminator(2, 1, d_num_layer * [hidden_dim]).cuda()
D_B = Discriminator(2, 1, d_num_layer * [hidden_dim]).cuda()
optimizer
# Adam hyper-parameters shared by both optimizers.
lr = 0.0002
beta1, beta2 = 0.5, 0.999

# Loss criteria. `d` is a mean-squared-error criterion (presumably the
# reconstruction loss — confirm against the training loop); `bce` is binary
# cross-entropy, which operates on probabilities in [0, 1].
d = nn.MSELoss()
bce = nn.BCELoss()

# One optimizer steps both discriminators, the other both generators.
optimizer_d = torch.optim.Adam(
    list(D_A.parameters()) + list(D_B.parameters()),
    lr=lr,
    betas=(beta1, beta2),
)
optimizer_g = torch.optim.Adam(
    list(G_AB.parameters()) + list(G_BA.parameters()),
    lr=lr,
    betas=(beta1, beta2),
)
training
# Training setup: epoch count, label values, and constant target vectors.
num_epoch = 50000
real_label = 1
fake_label = 0
# NOTE(review): batch_size is not defined in this snippet; assumed to be set
# earlier in the script — confirm.
# Target vector for "real" samples: allocated on the GPU (uninitialized), then
# filled in place with real_label. fill_ returns the tensor itself, so the
# result is discarded into `_`.
real_tensor = Variable(torch.FloatTensor(batch_size).cuda())
# NOTE(review): this in-place fill is a GPU kernel launch. The reported
# "cuda runtime error (8): invalid device function" (THCTensorMath.cu) is, as
# far as I can tell, CUDA's cudaErrorInvalidDeviceFunction, which usually means
# the installed PyTorch build has no kernels for this GPU's compute capability.
# Check the GPU against the build rather than changing this code — confirm.
_ = real_tensor.data.fill_(real_label)
# Sanity check: should print batch_size * real_label.
print(real_tensor.sum())
# Target vector for "fake" samples, filled with fake_label the same way.
fake_tensor = Variable(torch.FloatTensor(batch_size).cuda())
_ = fake_tensor.data.fill_(fake_label)
# Sanity check: should print batch_size * fake_label (i.e. 0).
print(fake_tensor.sum())
RuntimeError Traceback (most recent call last)
RuntimeError: cuda runtime error (8) : invalid device function at /py/conda-bld/pytorch_1493677666423/work/torch/lib/THC/generic/THCTensorMath.cu:15