
Upgrade to Keras2?

EladNoy opened this issue 7 years ago • 4 comments

With Keras1 now being deprecated, a Keras2 version would be greatly appreciated. I tried converting it myself, but Keras2 does not support the BatchNormalization mode=2 option, so it will probably require some sort of workaround.

EladNoy avatar Jun 05 '17 15:06 EladNoy

I was stuck on the same problem. I ended up writing a batchnorm version that always behaves like mode=2: you can edit the Keras file where BatchNormalization is defined and modify it so that it never uses the accumulated (moving) training statistics; a sketch of the idea follows.
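(A minimal sketch of that idea, as an alternative to editing the installed Keras source: subclass Keras 2's BatchNormalization and override `call()` so it always normalizes with the current batch statistics. The class name here is illustrative, not from this thread; the full stand-alone layer is posted below.)

```python
# Sketch only: force "mode=2"-style behaviour by always using batch statistics.
from keras import backend as K
from keras.layers import BatchNormalization


class AlwaysBatchStatsBN(BatchNormalization):
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        # Normalize with this batch's mean/variance; never fall back to the
        # accumulated moving averages (the inference branch is dropped).
        normed, mean, variance = K.normalize_batch_in_training(
            inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)
        return normed
```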

engharat avatar Jul 04 '17 15:07 engharat

Can you share the code, @engharat? Please.

frnk99 avatar Jul 15 '17 10:07 frnk99

Sure. Here is a link to the code: https://drive.google.com/open?id=0B0E8DCU-EnYRR2l3aV9oTkJORHc. Put the file in the same folder as your script and import it; then you can replace every occurrence of the BatchNormalization layer in the generator / discriminator code with the BatchNormGAN layer, as in the sketch below.
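For example (a minimal usage sketch; the file name `batchnorm_gan.py` and the generator layout are illustrative, not part of the original post):

```python
# Illustrative usage only: assumes the downloaded layer is saved as
# batchnorm_gan.py next to this script (file and model names are made up here).
from keras.models import Sequential
from keras.layers import Dense, Activation

from batchnorm_gan import BatchNormGAN  # instead of keras.layers.BatchNormalization

generator = Sequential()
generator.add(Dense(256, input_dim=100))
generator.add(BatchNormGAN())            # drop-in replacement for Keras 1's BatchNormalization(mode=2)
generator.add(Activation('relu'))
generator.add(Dense(784, activation='tanh'))
```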

Or if you prefer the code:

```python
# -*- coding: utf-8 -*-
from __future__ import absolute_import

from keras.engine import Layer, InputSpec
from keras import initializers
from keras import regularizers
from keras import constraints
from keras import backend as K
from keras.legacy import interfaces


class BatchNormGAN(Layer):
    """Batch normalization layer (Ioffe and Szegedy, 2014).

    Normalize the activations of the previous layer at each batch,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `BatchNormGAN`.
        momentum: Momentum for the moving average.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        moving_mean_initializer: Initializer for the moving mean.
        moving_variance_initializer: Initializer for the moving variance.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
    """

    @interfaces.legacy_batchnorm_support
    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 moving_mean_initializer='zeros',
                 moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(BatchNormGAN, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(moving_variance_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.moving_mean = self.add_weight(
            shape=shape,
            name='moving_mean',
            initializer=self.moving_mean_initializer,
            trainable=False)
        self.moving_variance = self.add_weight(
            shape=shape,
            name='moving_variance',
            initializer=self.moving_variance_initializer,
            trainable=False)
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Normalize along every axis except the feature axis.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]

        # Always normalize with the statistics of the current batch
        # (old Keras 1 "mode=2" behaviour); the usual K.in_train_phase
        # switch to the accumulated moving statistics is intentionally dropped.
        normed, mean, variance = K.normalize_batch_in_training(
            inputs, self.gamma, self.beta, reduction_axes,
            epsilon=self.epsilon)

        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(BatchNormGAN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```

engharat avatar Jul 16 '17 21:07 engharat

Thank you, @engharat!

frnk99 avatar Jul 22 '17 17:07 frnk99