pix2pix

Multi Channel Issue

Open • posadad1 opened this issue 11 months ago • 0 comments

Hi,

I am trying to train on six-channel input images, with a three-channel output image. As a quick fix, I added an extra folder parameter to the training function and passed it down to PairedImageDatastore. The images in this third folder are read alongside the A images and appended to them along the channel dimension, giving six channels:

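        function data = read(obj)
            imagesA = obj.ImagesA.read();
            imagesB = obj.ImagesB.read();
            imagesC = obj.ImagesC.read();   % extra folder: channels 4-6
            
            % for batch size 1, imageDatastore doesn't wrap the images in a cell
            if ~iscell(imagesA)
                imagesA = {imagesA};
                imagesB = {imagesB};
                imagesC = {imagesC};
            end
            
            % stack each C image behind its A image (3 + 3 = 6 channels);
            % done for every batch size, not only when the batch size is 1
            imagesA = cellfun(@(a, c) cat(3, a, c), imagesA, imagesC, ...
                              'UniformOutput', false);
            
            [transformedA, transformedB] = ...
                p2p.data.transformImagePair(imagesA, imagesB, ...
                                            obj.PreSize, obj.CropSize, ...
                                            obj.Augmenter);
            [A, B] = obj.normaliseImages(transformedA, transformedB);
            data = table(A, B);
        end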

Before training, I set the options as follows:

options = p2p.trainingOptions('InputChannels',6,'OutputChannels',3);

When training starts, it runs until the first progress update (iteration 50 of epoch 1) and then fails with the following output:

epoch: 1, it: 50, G: 77.888550 (L1: 0.771494, GAN: 0.739171), D: 0.664209
Error using dlnetwork/forward
Layer 'inputImage': Invalid input data. Invalid size of channel dimension. Layer expects input with channel dimension size
6 but received input with size 3.

Error in p2p.vis.TrainingPlot/updateImages (line 62)
            output = tanh(generator.forward(obj.ExampleInputs));

Error in p2p.vis.TrainingPlot/update (line 47)
            obj.updateImages(generator)

Error in p2p.train (line 100)
                    trainingPlot.update(logArgs{:}, g);

Error in trainDepth (line 11)
p2pModel = p2p.train(labelFolder, targetFolder,options);
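The stack trace shows that the failure is not in the training step itself but in p2p.vis.TrainingPlot/updateImages, where the generator is run on obj.ExampleInputs: the six-channel generator receives three-channel example images, so the extra folder apparently never reaches the training plot. One possible workaround is to build the plot's example inputs the same way as the training batches. The function below is a minimal sketch of that idea only. makeSixChannelExample is a hypothetical helper, not part of p2p, and it assumes 3-channel uint8 images of equal size and a generator that expects inputs scaled to [-1, 1] as an 'SSCB' dlarray; how TrainingPlot actually builds ExampleInputs is not shown in this issue.

```matlab
function example = makeSixChannelExample(imgA, imgC)
% makeSixChannelExample  Hypothetical helper (not part of p2p).
% Stacks an A image and its matching C image along the channel
% dimension so the example input has the six channels the generator
% expects.
%
% Assumptions: both inputs are 3-channel uint8 images of the same size,
% and the network takes inputs scaled to [-1, 1] as an 'SSCB' dlarray.
    assert(isequal(size(imgA), size(imgC)) && size(imgA, 3) == 3, ...
        'A and C must be 3-channel images of the same size.');

    stacked = cat(3, imgA, imgC);               % H x W x 6, uint8
    scaled  = single(stacked) / 255 * 2 - 1;    % map [0, 255] to [-1, 1]
    example = dlarray(scaled, 'SSCB');          % spatial, spatial, channel, batch
end
```

Passing the result to generator.forward should then match the 'inputImage' layer's channel size of 6, provided the scaling assumption agrees with what normaliseImages does in training.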

posadad1, Mar 03 '24 22:03