tf_unet

Ground truth completely ignored

[Open] mimre-anl opened this issue 7 years ago • 12 comments

Hi, I'm trying to use tf_unet to do some kind of outlier detection.

My data is a single-channel array of doubles, and my ground truth has the same shape but contains only 0s and 1s. Plotting the images works fine, but the network doesn't seem to care about the ground truth and just learns where the input data/image is 0 (its minimum value). The loss decreases at first, then levels out at some epoch, and after that the output doesn't change much. Typically, by the end of the first epoch the network has learned where the 0s are in the input data.

Any idea what could be wrong? Below are the predictions at initialization, after epoch 0, and after epoch 10.

Init: [image: _init]

Epoch 0: [image: epoch_0]

Epoch 10: [image: epoch_10]

mimre-anl, Jun 08 '18 18:06

  • Have you tried to train it for many more epochs? 10 is most likely not enough.
  • Then you could experiment with the number of filters and/or layers.
  • Try a different loss function (e.g. Dice)
  • Maybe you need to preprocess the input data
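For the Dice suggestion, here is a minimal NumPy sketch of a soft Dice loss, assuming softmax outputs and one-hot labels of shape (batch, ny, nx, n_class). It illustrates the technique only; it is not tf_unet's own implementation, and `soft_dice_loss` is a hypothetical name.

```python
import numpy as np

def soft_dice_loss(probs: np.ndarray, onehot: np.ndarray, eps: float = 1e-5) -> float:
    """Soft Dice loss, averaged over classes.

    probs:  predicted class probabilities, shape (batch, ny, nx, n_class)
    onehot: one-hot ground truth with the same shape
    """
    axes = (0, 1, 2)  # sum over batch and spatial axes, keep the class axis
    intersection = np.sum(probs * onehot, axis=axes)
    denom = np.sum(probs, axis=axes) + np.sum(onehot, axis=axes)
    dice_per_class = (2.0 * intersection + eps) / (denom + eps)
    return float(1.0 - dice_per_class.mean())
```

Because every class contributes equally to the average, predicting background everywhere on a mask with only a few foreground pixels gives a loss of about 0.5 here, even though its pixel accuracy would be very high.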

There seems to be quite a class imbalance. Can you "zoom" to the relevant regions and avoid the regions where you don't expect both classes?
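One way to "zoom" in is to train on random crops and discard crops that contain only one class. A minimal sketch, assuming a 2-D image and a 0/1 mask of the same shape; `sample_mixed_patches` and its parameters are hypothetical and not part of tf_unet.

```python
import numpy as np

def sample_mixed_patches(image, mask, patch=64, n_patches=8, max_tries=1000, seed=0):
    """Randomly crop patches, keeping only crops that contain both classes.

    image, mask: 2-D arrays of the same shape; mask holds 0/1 values.
    Returns two arrays of shape (k, patch, patch) with k <= n_patches.
    """
    rng = np.random.default_rng(seed)
    ny, nx = mask.shape
    imgs, masks = [], []
    for _ in range(max_tries):
        if len(imgs) == n_patches:
            break
        # Top-left corner of the crop (upper bound is exclusive)
        y = rng.integers(0, ny - patch + 1)
        x = rng.integers(0, nx - patch + 1)
        m = mask[y:y + patch, x:x + patch]
        # Skip crops that are entirely background or entirely foreground
        if 0 < m.sum() < m.size:
            imgs.append(image[y:y + patch, x:x + patch])
            masks.append(m)
    return np.array(imgs), np.array(masks)
```

The crops can then be given a channel axis and one-hot encoded before they are handed to a data provider.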

jakeret, Jun 08 '18 19:06

I've run it for 20 epochs so far, but after about 5 there is no change. How would you suggest preprocessing the data? The relevant regions could be anywhere except in the lower "triangles", which the network learns as outliers for some reason.

I'll experiment with different layers and/or losses for now and see if that helps.

Something that confuses me: why do the labels need 2 channels if it's just binary classification?

Edit: I ran it over the weekend for 100 epochs with more training samples. The output is still the same, just with a stronger separation between the interesting parts and the "non-values". [image: epoch_99]

Edit2: The reported error rates seem very strange as well. Looking at the prediction images and the ground truth, I don't see how the error could be only ~2%.
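One likely explanation, assuming the reported error is a pixel-wise misclassification rate: with a strong class imbalance, a model that never predicts an outlier still gets almost every pixel right. A small NumPy illustration with a synthetic mask (about 2% outliers, chosen to match the number above, not taken from the data in this thread):

```python
import numpy as np

# Synthetic 256x256 mask in which roughly 2% of the pixels are outliers
rng = np.random.default_rng(42)
truth = (rng.random((256, 256)) < 0.02).astype(np.int64)

# A "prediction" that never marks anything as an outlier
pred = np.zeros_like(truth)

pixel_error = np.mean(pred != truth)  # fraction of misclassified pixels
tp = np.sum((pred == 1) & (truth == 1))
fp = np.sum((pred == 1) & (truth == 0))
fn = np.sum((pred == 0) & (truth == 1))
iou = tp / (tp + fp + fn)  # intersection over union for the outlier class

print(f"pixel error: {pixel_error:.1%}, outlier IoU: {iou:.2f}")
```

The pixel error comes out near 2% while the IoU for the outlier class is 0, so an overlap metric such as IoU or Dice on the minority class exposes this failure where the error rate hides it.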

mimre-anl, Jun 08 '18 21:06

Maybe you could take the logarithm of the input to reduce its dynamic range. The labels have to be one-hot encoded; that's why they have 2 channels.
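A minimal NumPy sketch of both suggestions, assuming a 2-D input image and a 2-D mask of 0s and 1s; `log_scale` and `one_hot_binary` are hypothetical helper names. In this sketch channel 0 holds the background and channel 1 the foreground.

```python
import numpy as np

def log_scale(data: np.ndarray) -> np.ndarray:
    """Compress the dynamic range: shift so the minimum is 0, then take log(1 + x)."""
    return np.log1p(data - data.min())

def one_hot_binary(mask: np.ndarray) -> np.ndarray:
    """Turn a (ny, nx) 0/1 mask into a (ny, nx, 2) float32 one-hot label.

    Channel 0 = background (mask == 0), channel 1 = foreground (mask != 0).
    """
    fg = mask != 0  # explicit boolean, works for int, float or bool masks
    return np.stack([~fg, fg], axis=-1).astype(np.float32)
```

Each pixel then has exactly one channel set to 1, which is why two channels are needed even for a binary problem.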

It doesn't seem to learn anything. Have you checked whether the data and labels are in the correct format, or whether something goes wrong when reading in the data?
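A quick way to rule out format problems is to check the expected layout before training. A sketch under assumed shapes, data (n, ny, nx, channels) and labels (n, ny, nx, n_class); `check_pair` is a hypothetical helper, and the float32 check follows the dtype mentioned later in this thread.

```python
import numpy as np

def check_pair(data: np.ndarray, labels: np.ndarray, n_class: int = 2) -> list:
    """Return a list of problems found in a (data, labels) training pair."""
    problems = []
    if data.ndim != 4 or labels.ndim != 4:
        problems.append(f"expected 4-D arrays, got {data.ndim}-D data and {labels.ndim}-D labels")
        return problems
    if data.shape[:3] != labels.shape[:3]:
        problems.append(f"shapes differ: data {data.shape} vs labels {labels.shape}")
    if labels.shape[-1] != n_class:
        problems.append(f"labels have {labels.shape[-1]} channels, expected {n_class}")
    if labels.dtype != np.float32:
        problems.append(f"labels dtype is {labels.dtype}, expected float32")
    if not np.all(np.isin(labels, (0, 1))):
        problems.append("labels contain values other than 0 and 1")
    elif not np.allclose(labels.sum(axis=-1), 1):
        problems.append("labels are not one-hot: channels do not sum to 1 per pixel")
    if not np.all(np.isfinite(data)):
        problems.append("data contains NaN or inf")
    return problems
```

An empty list means the pair passed these checks; it does not prove the labels line up with the right images.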

jakeret, Jun 11 '18 15:06

I've now checked all the data a couple of times, and I have a run going where I normalize the input data before handing it to the SimpleDataProvider. So far, though, all it learns are the lower corners, and those are never marked in any of the masks.
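A minimal sketch of per-image standardization in NumPy, assuming the data is already stacked as (n, ny, nx, channels); `standardize` is a hypothetical helper. It is worth checking what the data provider itself does to the input (it may clip or rescale values in its own processing step), so that the two normalizations don't interact unexpectedly.

```python
import numpy as np

def standardize(data: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Per-image, per-channel zero-mean, unit-variance scaling.

    data: (n, ny, nx, channels); statistics are taken over the spatial axes.
    """
    mean = data.mean(axis=(1, 2), keepdims=True)
    std = data.std(axis=(1, 2), keepdims=True)
    return ((data - mean) / (std + eps)).astype(np.float32)
```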

Edit: does the mask need to be boolean, or does it work with an int array of 0s and 1s?
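One pitfall worth ruling out, although this thread doesn't show how the provider builds the background channel: NumPy's `~` operator is a bitwise NOT on integer arrays and a logical NOT only on boolean arrays (on float arrays it raises a TypeError). So if a provider derived the background as `~label`, an integer 0/1 mask would produce -1/-2 instead of 1/0.

```python
import numpy as np

int_mask = np.array([0, 1, 1, 0])   # 0/1 mask stored as integers
bool_mask = int_mask.astype(bool)   # the same mask as booleans

print(~int_mask)    # bitwise NOT on integers: [-1 -2 -2 -1]
print(~bool_mask)   # logical NOT on booleans: [ True False False  True]

# Converting to boolean first gives the intended background mask
# whether the original mask is int, float or bool.
fg = int_mask != 0
background = ~fg
```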

mimre-anl, Jun 11 '18 19:06

Hard to tell what is going on, sorry.

No, it should be dtype=float32.

jakeret, Jun 13 '18 13:06

I've got the same problem. [image: epoch_0]

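Here is the code:

from tf_unet import image_util, unet

# training_path, result_path and prediction_path are defined elsewhere in the script
training_provider = image_util.ImageDataProvider(training_path)
net = unet.Unet(channels=1, n_class=2, layers=3, features_root=32)
trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2))

# No line-continuation backslashes are needed inside the parentheses
path = trainer.train(training_provider, result_path,
                     training_iters=20, epochs=1, display_step=4,
                     prediction_path=prediction_path)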

The accuracy is also strange: it looks as if the prediction were being compared with the input image rather than with the ground truth.

2018-06-20 16:33:47,084 Iter 16, Minibatch Loss= 0.1619, Training Accuracy= 0.9621, Minibatch error= 3.8%
2018-06-20 16:33:48,675 Epoch 0, Average loss: 0.3090, learning rate: 0.2000
2018-06-20 16:33:49,564 Verification error= 5.6%, loss= 0.2374

By the way, shouldn't the initial prediction look like random noise? The network hasn't been trained yet, so aren't all its parameters random?

foothsu, Jun 20 '18 09:06

I'm having the same problem: fluctuating accuracy and no convergence to a solution (whatever it is). I think the issue is related to the dataset, and I believe something about the image format is escaping us (resolution, 8/16/32-bit depth, encoding of the mask classes).

I know this implementation was written for a different purpose than biomedical image segmentation, but I don't see why it shouldn't work on that kind of dataset, since that is what the U-Net architecture was originally presented for. So there should be a way to make it work as expected.

MXtreme, Jun 20 '18 10:06

I found a way to obtain good results:

  • Normalized input data (after the data provider call in the train function);
  • Added Batch Normalization;
  • U-Net with at least 5 layers and 20 epochs.

However, I still get fluctuating accuracy. In TensorBoard I have to set the smoothing option to 9.0 to actually see the accuracy increase. Does anyone have an explanation for this behavior?
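TensorBoard's smoothing setting applies an exponential moving average to the plotted values, so a higher setting hides step-to-step noise without changing the underlying trend. A minimal sketch of that kind of smoothing (TensorBoard's exact formula may differ, for example by bias correction); `ema_smooth` is a hypothetical name.

```python
import numpy as np

def ema_smooth(values, weight=0.9):
    """Exponential moving average of a sequence.

    weight in [0, 1): higher values give a smoother curve.
    The average is seeded with the first value.
    """
    smoothed = []
    last = values[0]
    for v in values:
        last = weight * last + (1 - weight) * v
        smoothed.append(last)
    return np.array(smoothed)
```

Smoothing only changes how the curve is displayed; a noisy raw curve still means each reported value is computed from very little data.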

MXtreme, Jun 23 '18 14:06

The default training batch_size is 1, and it is independent of training_iters (one epoch can even pass over the dataset more than once). With a single image per step, the reported accuracy is therefore likely to fluctuate.
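To see why batch size 1 makes the curve noisy, here is a small simulation with hypothetical per-image accuracies (not data from this thread). Assuming all images have the same size, the accuracy of a minibatch is the mean of its per-image accuracies, so its spread shrinks roughly with the square root of the batch size.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical per-image accuracies of a fixed model: mean 0.95, some images harder
per_image_acc = np.clip(rng.normal(0.95, 0.04, size=10_000), 0, 1)

def minibatch_accuracy_std(batch_size, n_batches=2_000):
    """Standard deviation of the accuracy reported for random minibatches."""
    idx = rng.integers(0, per_image_acc.size, size=(n_batches, batch_size))
    return per_image_acc[idx].mean(axis=1).std()

for bs in (1, 4, 16):
    print(bs, round(float(minibatch_accuracy_std(bs)), 4))
```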

foothsu, Jun 24 '18 11:06

@MXtreme Where should Batch Normalization be added? I'm new to DL, thanks a lot!

sanersbug, Aug 17 '18 09:08

@sanersbug, you should read issue #131, which is where I got the idea. I modified the conv2d function in the layers.py module to use the batch_norm function from tensorflow.contrib.layers. Here is the documentation. Be careful when using that function: pay attention to the "is_training" argument, which should be true only during the training phase. Also, it's worth reading up on Batch Normalization to understand better why it is placed between layers.
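To make the `is_training` point concrete, here is a minimal NumPy sketch of batch normalization (a hypothetical class for illustration, not the TensorFlow implementation): during training it normalizes with the current batch's statistics and updates running averages, and at inference it uses only the running averages.

```python
import numpy as np

class BatchNorm:
    """Minimal batch normalization over the channel axis of (n, ny, nx, c) inputs."""

    def __init__(self, channels, momentum=0.99, eps=1e-3):
        self.gamma = np.ones(channels)          # learned scale
        self.beta = np.zeros(channels)          # learned shift
        self.running_mean = np.zeros(channels)  # used at inference time
        self.running_var = np.ones(channels)
        self.momentum = momentum
        self.eps = eps

    def __call__(self, x, is_training):
        if is_training:
            # Normalize with the statistics of the current batch...
            mean = x.mean(axis=(0, 1, 2))
            var = x.var(axis=(0, 1, 2))
            # ...and fold them into the running averages for later inference
            m = self.momentum
            self.running_mean = m * self.running_mean + (1 - m) * mean
            self.running_var = m * self.running_var + (1 - m) * var
        else:
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta
```

With `tf.contrib.layers.batch_norm` in TF 1.x, the moving-average updates are placed by default in the `tf.GraphKeys.UPDATE_OPS` collection, so they have to be run together with the training op (or `updates_collections=None` has to be set); otherwise the statistics used at inference never move away from their initial values.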

NOTE: I had some problems importing tensorflow.contrib.layers in the PyCharm IDE. Take a look here if you run into the same problem with that editor.

I really had good results with this solution. Good Luck.

MXtreme, Aug 17 '18 10:08

@ thanks a lot! I'll try what you suggest. Sorry, I don't use PyCharm; I just use Sublime Text.

sanersbug, Aug 18 '18 04:08