image-super-resolution
image-super-resolution copied to clipboard
Gray scale training, "input shape" error message
Does this model support gray-scale (1D) training ?
To train gray scale training images, I add "c_dim" parameter on train code like below
rdn = RDN(arch_params={'C':6, 'D':20, 'G':64, 'G0':64, 'x':scale}, patch_size=lr_train_patch_size, c_dim=1)
But, I got below error message
ValueError: Input 0 of layer Conv_1 is incompatible with the layer: expected axis -1 of input shape to have value 3 but received input with shape [None, 64, 64, 1]
This is my code.
from ISR.models import RDN
from ISR.models import Discriminator
from ISR.models import Cut_VGG19
lr_train_patch_size = 32
layers_to_extract = [5, 9]
scale = 2
hr_train_patch_size = lr_train_patch_size * scale
rdn = RDN(arch_params={'C':6, 'D':20, 'G':64, 'G0':64, 'x':scale}, patch_size=lr_train_patch_size, c_dim=1)
f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)
from ISR.train import Trainer
loss_weights = {
'generator': 0.0,
'feature_extractor': 0.08333,
'discriminator': 0.01
}
losses = {
'generator': 'mae',
'feature_extractor': 'mse',
'discriminator': 'binary_crossentropy'
}
log_dirs = {'logs': './logs', 'weights': './weights'}
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}
trainer = Trainer(
generator=rdn,
discriminator=discr,
feature_extractor=f_ext,
lr_train_dir='data/gray/scale_x2/train/LR/',
hr_train_dir='data/gray/scale_x2/train/HR/',
lr_valid_dir='data/gray/scale_x2/val/LR/',
hr_valid_dir='data/gray/scale_x2/val/HR/',
loss_weights=loss_weights,
learning_rate=learning_rate,
flatness=flatness,
dataname='gray/scale_x2',
log_dirs=log_dirs,
metrics={'generator': 'PSNR'},
weights_generator=None,
weights_discriminator=None,
n_validation=100,
)
trainer.train(
epochs=50,
steps_per_epoch=100,
batch_size=20,
monitored_metrics={'val_generator_loss': 'max'}
)
could you post the entire error logs?