GANs-JAX
Generate more samples for 1_GANs.ipynb
Thanks for these demo notebooks!
I have a question about 1_GANs.ipynb, which I have modified to learn a
2D multi-blob distribution, i.e. each X_train sample is an (x_1, x_2) pair in [-1,1]^2.
I have replaced the Generator/Discriminator modules with a simpler Dense/BatchNorm/ReLU sequence:
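from functools import partial
from typing import Any

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state

class TrainState(train_state.TrainState):
    batch_stats: Any  # used only for BatchNorm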
class Generator(nn.Module):
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, z: jnp.ndarray, train: bool = True):
        batch_norm = partial(nn.BatchNorm, use_running_average=not train, axis=-1,
                             scale_init=normal_init(0.02), dtype=self.dtype)
        x = z.reshape((args['batch_size_p'], args['z_dim']))
        x = nn.Dense(512, name='Gen_layers_1', dtype=self.dtype)(x)
        x = batch_norm()(x)
        x = nn.relu(x)
        x = nn.Dense(512, name='Gen_layers_2', dtype=self.dtype)(x)
        x = batch_norm()(x)
        x = nn.relu(x)
        x = nn.Dense(512, name='Gen_layers_3', dtype=self.dtype)(x)
        x = nn.relu(x)
        x = nn.Dense(args['x_dim'], name='Gen_layers_4', dtype=self.dtype)(x)
        return x
class Discriminator(nn.Module):
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        batch_norm = partial(nn.BatchNorm, use_running_average=not train, axis=-1,
                             scale_init=normal_init(0.02), dtype=self.dtype)
        x = nn.Dense(512, name='Dis_layers_1', dtype=self.dtype)(x)
        x = batch_norm()(x)
        x = nn.relu(x)
        x = nn.Dense(512, name='Dis_layers_2', dtype=self.dtype)(x)
        x = batch_norm()(x)
        x = nn.relu(x)
        x = nn.Dense(512, name='Dis_layers_3', dtype=self.dtype)(x)
        x = batch_norm()(x)
        x = nn.relu(x)
        x = nn.Dense(args['x_dim'], name='Dis_layers_4', dtype=self.dtype)(x)
        x = nn.sigmoid(x)  # [0,1]
        return x  # x.reshape((args['batch_size_p'], -1))
Keeping your training phase unchanged, I manage to get the same results after epoch 100.
Now I would like to draw more samples for the image saved at each epoch, but I cannot manage
to change args['batch_size'] to another value. Could you give me some advice, please?
generator_input = jax.random.normal(key, (args['batch_size'], args['z_dim']))
generator_input = shard(generator_input)
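A possible cause, offered as an assumption rather than a confirmed diagnosis: `shard` splits the leading axis across the local devices, and the Generator above reshapes its input to a fixed `args['batch_size_p']` rows, so any other sample count would fail at that reshape. A minimal sketch of drawing an arbitrary number of latent vectors and laying them out per device by hand (the helper name `make_sharded_noise` and the sizes 1000 and 64 are hypothetical, and a plain `reshape` stands in for `shard`):

```python
import jax

def make_sharded_noise(key, n_samples, z_dim):
    """Draw latent noise shaped (n_devices, per_device, z_dim)."""
    n_devices = jax.local_device_count()
    # Round up so the sample count divides evenly across the local devices.
    per_device = -(-n_samples // n_devices)
    z = jax.random.normal(key, (n_devices * per_device, z_dim))
    # Same layout that a per-device (pmap-style) call expects on its leading axis.
    return z.reshape((n_devices, per_device, z_dim))

key = jax.random.PRNGKey(0)
z = make_sharded_noise(key, n_samples=1000, z_dim=64)
print(z.shape)
```

With this layout, changing the Generator's first line to `x = z.reshape((z.shape[0], -1))` would stop pinning the per-device batch size, assuming nothing else in the notebook depends on `args['batch_size_p']`.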
Also, is there something simple I could do to improve the model or the loss definition to get better learning? Thanks