tf_unet icon indicating copy to clipboard operation
tf_unet copied to clipboard

cannot identify image file '/home/tf_unet/data/train/47.tif'

Open sanersbug opened this issue 7 years ago • 4 comments

Number of files used: 52 Traceback (most recent call last): File "myTrain.py", line 9, in data_provider = image_util.ImageDataProvider(search_path, a_min=0, a_max=255, data_suffix='.tif', mask_suffix='_V1_poly.tif') File "/home/tf_unet/tf_unet/image_util.py", line 169, in init img = self._load_file(self.data_files[0]) File "/home/tf_unet/tf_unet/image_util.py", line 178, in _load_file return np.array(Image.open(path), dtype) File "/home/anaconda3/lib/python3.6/site-packages/PIL/Image.py", line 2590, in open % (filename if filename else fp))

my image file have 4 channels including r,g,b,nir ,so how to solve the problem?

I know the function Image.open cannot deal with 'tif ' ,but how to deal with the 'tif' type??

thanks a lot.

sanersbug avatar Jul 24 '18 02:07 sanersbug

You could write your own DataProvider by subclassing the ImageDataProvider and overwriting the _load_file function

jakeret avatar Jul 24 '18 18:07 jakeret

@jakeret Thank you very much !

When i train , it return this error

Traceback (most recent call last): File "myTrain.py", line 13, in path = trainer.train(data_provider, './unet_trained', training_iters=32, epochs=20, display_step=2) File "/home/tf_unet/tf_unet/unet.py", line 422, in train test_x, test_y = data_provider(self.verification_batch_size) File "/home/tf_unet/tf_unet/image_util.py", line 92, in call train_data, labels = self._load_data_and_label() File "/home/tf_unet/tf_unet/image_util.py", line 62, in _load_data_and_label return train_data.reshape(1, ny, nx, 4), labels.reshape(1, ny, nx, 1)

The follow is the image_util.py i changed.

from future import print_function, division, absolute_import, unicode_literals #import cv2 import glob import numpy as np from PIL import Image from osgeo import gdal

class BaseDataProvider(object): channels = 4 n_class = 1

def __init__(self, a_min=None, a_max=None):
    self.a_min = a_min if a_min is not None else -np.inf
    self.a_max = a_max if a_min is not None else np.inf

def _load_data_and_label(self):
    data, label = self._next_data()
    #print('data',data,'label',label)
    train_data = self._process_data(data)
    labels = self._process_labels(label)

    train_data, labels = self._post_process(train_data, labels)

    nx = train_data.shape[1]
    ny = train_data.shape[0]

    #return train_data.reshape(1, ny, nx, self.channels), labels.reshape(1, ny, nx, self.n_class)
    return train_data.reshape(1, ny, nx, 4), labels.reshape(1, ny, nx, 1)

def _process_labels(self, label):
    if self.n_class == 2:
        nx = label.shape[1]
        ny = label.shape[0]
        labels = np.zeros((ny, nx, self.n_class), dtype=np.float32)
        labels[..., 1] = label
        labels[..., 0] = ~label
        return labels

    return label

def _process_data(self, data):
    # normalization
    data = np.clip(np.fabs(data), self.a_min, self.a_max)
    data -= np.amin(data)
    data /= np.amax(data)
    return data

def _post_process(self, data, labels):
    """
    Post processing hook that can be used for data augmentation
    
    :param data: the data array
    :param labels: the label array
    """
    return data, labels

def __call__(self, n):
    train_data, labels = self._load_data_and_label()
    nx = train_data.shape[1]
    ny = train_data.shape[2]

    X = np.zeros((n, nx, ny, self.channels))
    Y = np.zeros((n, nx, ny, self.n_class))

    X[0] = train_data
    Y[0] = labels
    for i in range(1, n):
          train_data, labels = self._load_data_and_label()
        X[i] = train_data
        Y[i] = labels

    return X, Y

class SimpleDataProvider(BaseDataProvider): def init(self, data, label, a_min=None, a_max=None, channels=1, n_class = 2): super(SimpleDataProvider, self).init(a_min, a_max) self.data = data self.label = label self.file_count = data.shape[0] self.n_class = n_class self.channels = channels

def _next_data(self):
    idx = np.random.choice(self.file_count)
    return self.data[idx], self.label[idx]

class ImageDataProvider(BaseDataProvider): def init(self, search_path, a_min=None, a_max=None, data_suffix=".tif", mask_suffix='_mask.tif', shuffle_data=True, n_class = 2): super(ImageDataProvider, self).init(a_min, a_max) self.data_suffix = data_suffix self.mask_suffix = mask_suffix self.file_idx = -1 self.shuffle_data = shuffle_data self.n_class = n_class

    self.data_files = self._find_data_files(search_path)
                                                  
    if self.shuffle_data:
        np.random.shuffle(self.data_files)

    assert len(self.data_files) > 0, "No training files"
    print("Number of files used: %s" % len(self.data_files))

    img = self._load_file(self.data_files[0])
    self.channels = 1 if len(img.shape) == 2 else img.shape[-1]

def _find_data_files(self, search_path):
    all_files = glob.glob(search_path)
    return [name for name in all_files if self.data_suffix in name and not self.mask_suffix in name]


def _load_file(self, path, dtype=np.float32):
    driver = gdal.GetDriverByName('GTiff')
    driver.Register()
    dataset = gdal.Open(path)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    return dataset.ReadAsArray(0,0,im_width,im_height)
    #return np.array(Image.open(path), dtype)
    # return np.squeeze(cv2.imread(image_name, cv2.IMREAD_GRAYSCALE))

def _cylce_file(self):
    self.file_idx += 1
    if self.file_idx >= len(self.data_files):
        self.file_idx = 0
        if self.shuffle_data:
            np.random.shuffle(self.data_files)

def _next_data(self):
    self._cylce_file()
    image_name = self.data_files[self.file_idx]
    label_name = image_name.replace(self.data_suffix, self.mask_suffix)

    img = self._load_file(image_name, np.float32)
    label = self._load_file(label_name, np.bool)

    return img,label

sanersbug avatar Jul 25 '18 07:07 sanersbug

The error message seems to be missing. There is only the traceback I'm also not quite sure what you changed. This looks like an entire copy of the tf_unet image_util.py..

jakeret avatar Jul 26 '18 06:07 jakeret

@jakeret OK,thank you , Finally, I changed ImageDataProvider and overwriting the _load_file function, it's run success ,but in the 'util.py' , it has some error.

The main question is in the function 'combine_img_prediction'

I changed it and print some parameters, it show in the terminal like:

to_rgb(crop_to_shape(data, pred.shape).reshape(-1, ny, ch) (3840, 960, 4)

to_rgb(crop_to_shape(gt, pred.shape).reshape(-1, ny, 1)) (3840, 960, 3)

/home/tf_unet/tf_unet/util.py:74: RuntimeWarning: invalid value encountered in true_divide img /= np.amax(img)

to_rgb(pred.reshape(-1, ny, 1))) (3840, 960, 3)

Traceback (most recent call last): File "myTrain.py", line 13, in path = trainer.train(data_provider, './unet_trained', training_iters=32, epochs=20, display_step=2) File "/home/tf_unet/tf_unet/unet.py", line 423, in train pred_shape = self.store_prediction(sess, test_x, test_y, "_init") File "/home/tf_unet/tf_unet/unet.py", line 475, in store_prediction img = util.combine_img_prediction(batch_x, batch_y, prediction) File "/home/tf_unet/tf_unet/util.py", line 111, in combine_img_prediction to_rgb(pred.reshape(-1, ny, 1))), axis=1) ValueError: all the input array dimensions except for the concatenation axis must match exactly

def combine_img_prediction(data, gt, pred): """ Combines the data, grouth thruth and the prediction into one rgb image

:param data: the data tensor
:param gt: the ground thruth tensor
:param pred: the prediction tensor

:returns img: the concatenated rgb image 
"""
ny = pred.shape[2]
ch = data.shape[3]
#print('pred.shape: ', pred.shape)
print('ny,ch: ',ny,ch)
#print('data.shape: ', data.shape)
#print('gt.shape: ',gt.shape)
print('to_rgb(crop_to_shape(data, pred.shape).reshape(-1, ny, ch)',to_rgb(crop_to_shape(data, pred.shape).reshape(-1, ny, ch)).shape)
print('to_rgb(crop_to_shape(gt, pred.shape).reshape(-1, ny, 1))',to_rgb(crop_to_shape(gt, pred.shape).reshape(-1, ny, 1)).shape)
print('to_rgb(pred.reshape(-1, ny, 1)))',to_rgb(pred.reshape(-1, ny, 1)).shape)
img = np.concatenate((to_rgb(crop_to_shape(data, pred.shape).reshape(-1, ny, ch)),
                      to_rgb(crop_to_shape(gt, pred.shape).reshape(-1, ny, 1)),
                      to_rgb(pred.reshape(-1, ny, 1))), axis=1)
#to_rgb(crop_to_shape(gt[..., 1], pred.shape).reshape(-1, ny, 1))a
#to_rgb(pred[..., 1].reshape(-1, ny, 1))), axis=1)
return img

my train data have 4 channels (r,g,b,nir),but it seems like the function cannot deal with it !

sanersbug avatar Jul 26 '18 07:07 sanersbug