
Gray scale training, "input shape" error message

Open 89douner opened this issue 4 years ago • 1 comment

Does this model support grayscale (single-channel) training?

To train on grayscale images, I added the "c_dim" parameter to the training code like below:

rdn = RDN(arch_params={'C':6, 'D':20, 'G':64, 'G0':64, 'x':scale}, patch_size=lr_train_patch_size, c_dim=1)

But I got the error message below:

ValueError: Input 0 of layer Conv_1 is incompatible with the layer: expected axis -1 of input shape to have value 3 but received input with shape [None, 64, 64, 1]

This is my code.

from ISR.models import RDN
from ISR.models import Discriminator
from ISR.models import Cut_VGG19

lr_train_patch_size = 32
layers_to_extract = [5, 9]
scale = 2
hr_train_patch_size = lr_train_patch_size * scale

rdn = RDN(arch_params={'C':6, 'D':20, 'G':64, 'G0':64, 'x':scale}, patch_size=lr_train_patch_size, c_dim=1)
f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)

from ISR.train import Trainer
loss_weights = {
  'generator': 0.0,
  'feature_extractor': 0.08333,
  'discriminator': 0.01
}
losses = {
  'generator': 'mae',
  'feature_extractor': 'mse',
  'discriminator': 'binary_crossentropy'
}

log_dirs = {'logs': './logs', 'weights': './weights'}
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}

trainer = Trainer(
    generator=rdn,
    discriminator=discr,
    feature_extractor=f_ext,

    lr_train_dir='data/gray/scale_x2/train/LR/',
    hr_train_dir='data/gray/scale_x2/train/HR/',
    lr_valid_dir='data/gray/scale_x2/val/LR/',
    hr_valid_dir='data/gray/scale_x2/val/HR/',

    loss_weights=loss_weights,
    learning_rate=learning_rate,
    flatness=flatness,
    dataname='gray/scale_x2',
    log_dirs=log_dirs,
    metrics={'generator': 'PSNR'},
    weights_generator=None,
    weights_discriminator=None,
    n_validation=100,
)
trainer.train(
    epochs=50,
    steps_per_epoch=100,
    batch_size=20,
    monitored_metrics={'val_generator_loss': 'max'}
)

89douner avatar Apr 21 '20 05:04 89douner

Could you post the entire error logs?

cfrancesco avatar Jun 18 '20 13:06 cfrancesco