segmentation_models
Unet with grayscale image not working.
I am working on an MRI dataset where the input and mask are greyscale images. How do I train the model for this particular case? And also, if I don't want to use a backbone, how do I do that?
I think you technically need a backbone, which acts as the encoder for the network. If you want to start from scratch, you could load any backbone (like VGG16, the default), just not load the ImageNet weights with it, and then set the entire model to trainable?
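A minimal sketch of that idea, assuming the segmentation_models Keras API (the backbone, loss, and optimizer here are just example choices):

import segmentation_models as sm

# Build a Unet whose encoder is randomly initialised (no ImageNet weights),
# so the whole network is trained from scratch.
model = sm.Unet('vgg16', encoder_weights=None, classes=1, activation='sigmoid')

# Explicitly mark every layer as trainable (it already is when no pre-trained
# weights are loaded, but this makes the intent clear).
for layer in model.layers:
    layer.trainable = True

model.compile(optimizer='adam', loss='binary_crossentropy')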
@JordanMakesMaps What if I don't want to use an encoder? I just want to use the Unet model without pretrained encoders, just a simple Unet without a backbone, because here the Unet only accepts 3-channel input images.
@ninjakx so if you don't want to use a pre-trained encoder, just don't load the ImageNet weights. But by definition a Unet architecture consists of an encoder (which is one of the backbones) followed by a decoder, which is essentially a mirror image of the encoder run in reverse, doing upsampling operations instead of downsampling.
But if you're looking for a Unet architecture whose encoder is not one of the defined convolutional bases (VGG, ResNet, DenseNet, EfficientNet), try checking out this repo, though the results are nowhere near as good as the ones from this repo.
This repo should allow you to use images that are not three channels?
# from the readme:
# if you set the number of input channels to something other than 3, you have to set encoder_weights=None
# (how to handle such a case while keeping encoder_weights='imagenet' is described in the docs)
from segmentation_models import Unet

model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
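Adapting that readme example to the grayscale case in this thread would look roughly like the sketch below (the resnet34 backbone, loss, and training arrays are just placeholders; images and masks are assumed to be arrays shaped (N, H, W, 1) scaled to [0, 1]):

from segmentation_models import Unet

# Single-channel (grayscale) input, so encoder_weights must be None.
model = Unet('resnet34', input_shape=(None, None, 1), encoder_weights=None,
             classes=1, activation='sigmoid')
model.compile(optimizer='adam', loss='binary_crossentropy')

# x_train: (N, H, W, 1) grayscale images, y_train: (N, H, W, 1) binary masks.
# model.fit(x_train, y_train, batch_size=8, epochs=50)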
@JordanMakesMaps you're very helpful as usual, thanks a lot for the answer!
@JordanMakesMaps I am currently working on a dataset of grayscale images and I didn't use a pre-trained model as a backbone. When I run my pipeline it executes without errors, but the predicted mask images being saved are all black and don't show anything. I am very confused about what my next step should be. Please help me out.
Hi @ishreyaa07, walk me through how you're loading these images as a dataset. Are they originally 8-bit (0-255)? Are you stacking them so the dimensions are (w x h x 3)?
After you've trained the model and are making predictions, what's your process? How are you displaying and saving the results? Are they actually all black (0's)? Please provide some code, otherwise there's not much I can help with.
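For example, a quick check like this (a sketch; model, image, and the 0.5 threshold are placeholders) will tell you whether the predictions are genuinely all zeros or are just float masks in [0, 1] being written out without rescaling, which also looks completely black as an 8-bit image:

import numpy as np
import cv2

# image is assumed to be a batch of shape (1, H, W, C), preprocessed the same
# way as the training data.
pred = model.predict(image)[0, ..., 0]          # float mask, roughly in [0, 1]
print('min:', pred.min(), 'max:', pred.max())   # max near 0 means the model really predicts nothing

# Threshold and rescale to 0-255 before saving, otherwise the PNG looks black.
mask = (pred > 0.5).astype(np.uint8) * 255
cv2.imwrite('prediction.png', mask)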