
Add more generated samples for 1_GANs.ipynb

Open · jecampagne opened this issue 9 months ago · 0 comments

Thanks for these demo notebooks!

I have a question concerning 1_GANs.ipynb, which I have modified to produce a 2D multi-blob distribution, i.e. each X_train sample is an (x_1, x_2) couple in [-1, 1]^2:

[Image: scatter plot of the 2D multi-blob training distribution]
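For reference, here is a minimal sketch of how such a multi-blob dataset can be built (the blob centres, spread, and counts below are illustrative, not the exact values I used):

import jax
import jax.numpy as jnp

def make_blobs(key, n_per_blob=1000, sigma=0.05):
  # illustrative blob centres inside [-1, 1]^2 (not the ones actually used)
  centers = jnp.array([[-0.5, -0.5], [-0.5, 0.5], [0.5, -0.5], [0.5, 0.5]])
  keys = jax.random.split(key, len(centers))
  blobs = [c + sigma * jax.random.normal(k, (n_per_blob, 2))
           for c, k in zip(centers, keys)]
  # keep every (x_1, x_2) couple inside [-1, 1]^2
  return jnp.clip(jnp.concatenate(blobs), -1.0, 1.0)

X_train = make_blobs(jax.random.PRNGKey(0))  # shape (4000, 2)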

I have modified the Generator/Discriminator modules into a simpler Dense/BatchNorm/ReLU sequence:

from functools import partial
from typing import Any

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state

# args (batch sizes, dims) is the notebook's hyper-parameter dict;
# normal_init is assumed to be e.g. jax.nn.initializers.normal
normal_init = jax.nn.initializers.normal

class TrainState(train_state.TrainState):
  batch_stats: Any  # used only for the BatchNorm running statistics

class Generator(nn.Module):
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, z: jnp.ndarray, train: bool = True):
    batch_norm = partial(nn.BatchNorm, use_running_average=not train, axis=-1,
                         scale_init=normal_init(0.02), dtype=self.dtype)

    # note: the leading dimension is hard-coded to the per-device batch size
    x = z.reshape((args['batch_size_p'], args['z_dim']))
    x = nn.Dense(512, name='Gen_layers_1', dtype=self.dtype)(x)
    x = batch_norm()(x)
    x = nn.relu(x)
    x = nn.Dense(512, name='Gen_layers_2', dtype=self.dtype)(x)
    x = batch_norm()(x)
    x = nn.relu(x)
    x = nn.Dense(512, name='Gen_layers_3', dtype=self.dtype)(x)
    x = nn.relu(x)
    x = nn.Dense(args['x_dim'], name='Gen_layers_4', dtype=self.dtype)(x)
    return x


class Discriminator(nn.Module):
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, x: jnp.ndarray, train: bool = True):
    batch_norm = partial(nn.BatchNorm, use_running_average=not train, axis=-1,
                         scale_init=normal_init(0.02), dtype=self.dtype)

    x = nn.Dense(512, name='Dis_layers_1', dtype=self.dtype)(x)
    x = batch_norm()(x)
    x = nn.relu(x)
    x = nn.Dense(512, name='Dis_layers_2', dtype=self.dtype)(x)
    x = batch_norm()(x)
    x = nn.relu(x)
    x = nn.Dense(512, name='Dis_layers_3', dtype=self.dtype)(x)
    x = batch_norm()(x)
    x = nn.relu(x)
    x = nn.Dense(args['x_dim'], name='Dis_layers_4', dtype=self.dtype)(x)
    x = nn.sigmoid(x)  # per-sample scores in [0, 1]
    return x
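
For completeness, this is roughly how I create the training states so that batch_stats is tracked next to params (a sketch: the RNG keys, optimizer, and learning rate below are illustrative, not necessarily what the notebook uses):

import optax

def create_state(model, rng, input_shape):
  # init returns {'params': ..., 'batch_stats': ...} because of nn.BatchNorm
  variables = model.init(rng, jnp.ones(input_shape), train=False)
  return TrainState.create(apply_fn=model.apply,
                           params=variables['params'],
                           batch_stats=variables['batch_stats'],
                           tx=optax.adam(2e-4))  # illustrative optimizer

gen_state = create_state(Generator(), jax.random.PRNGKey(1),
                         (args['batch_size_p'], args['z_dim']))
disc_state = create_state(Discriminator(), jax.random.PRNGKey(2),
                          (args['batch_size_p'], args['x_dim']))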

Keeping your training phase unchanged, I manage to get the same results after epoch 100:

[Image: generated samples after epoch 100]

Now, I would like to get more samples in each saved epoch image, but I do not manage to change args['batch_size'] to something else. Can you give me some advice, please?

generator_input = jax.random.normal(key, (args['batch_size'], args['z_dim']))
generator_input = shard(generator_input)
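
For illustration, this is roughly what I would like to be able to do (n_samples is a hypothetical name, not a notebook parameter); I suspect the hard-coded reshape to args['batch_size_p'] inside Generator.__call__ is what breaks as soon as the leading dimension changes:

n_samples = 4096  # hypothetical; must stay divisible by jax.device_count() for shard()
generator_input = jax.random.normal(key, (n_samples, args['z_dim']))
generator_input = shard(generator_input)
# a shape-agnostic Generator would instead start with
#   x = z.reshape((-1, args['z_dim']))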

Also, what could I simply do to improve the modelling or the loss definition to get better model learning? Thanks!

jecampagne · Jan 23 '25 14:01