tf_unet
tf_unet copied to clipboard
cannot identify image file '/home/tf_unet/data/train/47.tif'
Number of files used: 52
Traceback (most recent call last):
File "myTrain.py", line 9, in
My image files have 4 channels (R, G, B, NIR), so how can I solve this problem?
I know `Image.open` cannot read this .tif file, but how should I handle .tif images like this?
Thanks a lot.
You could write your own DataProvider by subclassing `ImageDataProvider` and overriding the `_load_file` method.
@jakeret Thank you very much!
When I train, it returns this error:
Traceback (most recent call last):
File "myTrain.py", line 13, in
The following is the image_util.py I changed:
from future import print_function, division, absolute_import, unicode_literals #import cv2 import glob import numpy as np from PIL import Image from osgeo import gdal
class BaseDataProvider(object):
    """
    Abstract base class for data providers that produce (data, label) batches.

    Subclasses must implement ``_next_data()``, returning one ``(data, label)``
    pair per call, and should set ``channels`` / ``n_class`` to match the data
    they deliver (the class-level values below are only fallbacks).

    :param a_min: (optional) min value used for clipping the data
    :param a_max: (optional) max value used for clipping the data
    """

    channels = 4
    n_class = 1

    def __init__(self, a_min=None, a_max=None):
        self.a_min = a_min if a_min is not None else -np.inf
        # Test a_max itself: previously a_min was tested here, so an explicit
        # a_max was silently dropped whenever a_min was left as None.
        self.a_max = a_max if a_max is not None else np.inf

    def _load_data_and_label(self):
        """
        Fetch one sample and run it through normalization and post-processing.

        :returns: tuple ``(data, labels)`` reshaped to
                  ``(1, ny, nx, self.channels)`` and ``(1, ny, nx, self.n_class)``
        """
        data, label = self._next_data()

        train_data = self._process_data(data)
        labels = self._process_labels(label)

        train_data, labels = self._post_process(train_data, labels)

        ny = train_data.shape[0]
        nx = train_data.shape[1]

        # Use the configured channel / class counts instead of hard-coded 4 / 1
        # so single-band data and two-class labels reshape correctly too.
        return (train_data.reshape(1, ny, nx, self.channels),
                labels.reshape(1, ny, nx, self.n_class))

    def _process_labels(self, label):
        """
        Expand a binary mask into a two-channel one-hot array when n_class == 2.

        :param label: 2-D mask (any dtype; non-zero means foreground)
        :returns: float32 array (ny, nx, 2) for n_class == 2, otherwise ``label``
        """
        if self.n_class == 2:
            ny = label.shape[0]
            nx = label.shape[1]
            # Convert to bool first: ``~`` on integer masks (as returned by GDAL)
            # is a bitwise NOT (e.g. ~1 == 254 for uint8), not a logical one.
            mask = np.asarray(label, dtype=bool)
            labels = np.zeros((ny, nx, self.n_class), dtype=np.float32)
            labels[..., 1] = mask
            labels[..., 0] = ~mask
            return labels

        return label

    def _process_data(self, data):
        """
        Normalize data to [0, 1] after clipping |data| to [a_min, a_max].

        Constant input is returned as all zeros instead of dividing by zero.
        """
        data = np.clip(np.fabs(data), self.a_min, self.a_max)
        data -= np.amin(data)
        max_value = np.amax(data)
        if max_value != 0:
            data /= max_value
        return data

    def _post_process(self, data, labels):
        """
        Post processing hook that can be used for data augmentation

        :param data: the data array
        :param labels: the label array
        """
        return data, labels

    def __call__(self, n):
        """
        Draw a batch of ``n`` (n >= 1) samples.

        :returns: tuple ``(X, Y)`` with shapes ``(n, ny, nx, channels)`` and
                  ``(n, ny, nx, n_class)``
        """
        train_data, labels = self._load_data_and_label()
        # _load_data_and_label returns (1, ny, nx, ...): axis 1 is y, axis 2 is x.
        ny = train_data.shape[1]
        nx = train_data.shape[2]

        X = np.zeros((n, ny, nx, self.channels))
        Y = np.zeros((n, ny, nx, self.n_class))

        X[0] = train_data
        Y[0] = labels
        for i in range(1, n):
            train_data, labels = self._load_data_and_label()
            X[i] = train_data
            Y[i] = labels

        return X, Y
class SimpleDataProvider(BaseDataProvider):
    """
    Data provider that samples from in-memory arrays.

    ``data`` and ``label`` are indexed along axis 0, one sample per index; each
    call to ``_next_data`` picks a random index.

    :param data: array of samples (indexed along axis 0)
    :param label: array of labels (indexed along axis 0)
    :param a_min: (optional) min value used for clipping the data
    :param a_max: (optional) max value used for clipping the data
    :param channels: (optional) number of channels, default=1
    :param n_class: (optional) number of classes, default=2
    """

    # The paste showed ``init`` / ``super().init`` (markdown ate the dunders);
    # it must be ``__init__`` or the constructor is never run and the base
    # constructor receives (data, label) as (a_min, a_max).
    def __init__(self, data, label, a_min=None, a_max=None, channels=1, n_class=2):
        super(SimpleDataProvider, self).__init__(a_min, a_max)
        self.data = data
        self.label = label
        self.file_count = data.shape[0]
        self.n_class = n_class
        self.channels = channels

    def _next_data(self):
        """Return a randomly chosen ``(data, label)`` pair."""
        idx = np.random.choice(self.file_count)
        return self.data[idx], self.label[idx]
class ImageDataProvider(BaseDataProvider):
    """
    Data provider for image files on disk, read through GDAL so that multi-band
    rasters (e.g. r, g, b, nir GeoTIFFs) can be used.

    Files matching ``search_path`` that contain ``data_suffix`` (but not
    ``mask_suffix``) are used as images; the mask for each image is found by
    replacing ``data_suffix`` with ``mask_suffix`` in its name.

    :param search_path: a glob search pattern to find all data and label images
    :param a_min: (optional) min value used for clipping the data
    :param a_max: (optional) max value used for clipping the data
    :param data_suffix: suffix pattern for the data images. Default '.tif'
    :param mask_suffix: suffix pattern for the label images. Default '_mask.tif'
    :param shuffle_data: if the order of the loaded file path should be randomized. Default 'True'
    :param n_class: (optional) number of classes, default=2
    """

    # ``__init__`` (not ``init``): the paste lost the dunders, which would leave
    # this constructor uncalled.
    def __init__(self, search_path, a_min=None, a_max=None, data_suffix=".tif",
                 mask_suffix='_mask.tif', shuffle_data=True, n_class=2):
        super(ImageDataProvider, self).__init__(a_min, a_max)
        self.data_suffix = data_suffix
        self.mask_suffix = mask_suffix
        self.file_idx = -1
        self.shuffle_data = shuffle_data
        self.n_class = n_class

        self.data_files = self._find_data_files(search_path)

        if self.shuffle_data:
            np.random.shuffle(self.data_files)

        assert len(self.data_files) > 0, "No training files"
        print("Number of files used: %s" % len(self.data_files))

        # _load_file returns channels-last arrays, so the last axis is the
        # band count for multi-band images.
        img = self._load_file(self.data_files[0])
        self.channels = 1 if len(img.shape) == 2 else img.shape[-1]

    def _find_data_files(self, search_path):
        """Return data image paths matching the glob, excluding mask files."""
        all_files = glob.glob(search_path)
        return [name for name in all_files
                if self.data_suffix in name and self.mask_suffix not in name]

    def _load_file(self, path, dtype=np.float32):
        """
        Read a raster file with GDAL.

        :param path: file path
        :param dtype: dtype of the returned array
        :returns: ``(rows, cols)`` array for single-band files or
                  ``(rows, cols, bands)`` for multi-band files, cast to ``dtype``
        :raises IOError: if GDAL cannot open ``path``
        """
        driver = gdal.GetDriverByName('GTiff')
        driver.Register()
        dataset = gdal.Open(path)
        if dataset is None:
            # gdal.Open returns None instead of raising on failure.
            raise IOError("GDAL could not open file '%s'" % path)

        array = dataset.ReadAsArray(0, 0, dataset.RasterXSize, dataset.RasterYSize)
        dataset = None  # release the GDAL dataset handle

        if array.ndim == 3:
            # GDAL returns multi-band rasters band-first (bands, rows, cols);
            # the rest of the pipeline expects channels last.
            array = np.moveaxis(array, 0, -1)

        # Honor the requested dtype (previously ignored, so masks were not bool).
        return array.astype(dtype)

    def _cylce_file(self):
        """Advance to the next file, reshuffling after each full pass."""
        self.file_idx += 1
        if self.file_idx >= len(self.data_files):
            self.file_idx = 0
            if self.shuffle_data:
                np.random.shuffle(self.data_files)

    def _next_data(self):
        """Load the next image and its mask."""
        self._cylce_file()
        image_name = self.data_files[self.file_idx]
        label_name = image_name.replace(self.data_suffix, self.mask_suffix)

        img = self._load_file(image_name, np.float32)
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        label = self._load_file(label_name, bool)

        return img, label
The error message seems to be missing; there is only the traceback. I'm also not quite sure what you changed — this looks like an entire copy of the tf_unet image_util.py.
@jakeret OK, thank you. In the end I changed ImageDataProvider and overrode the `_load_file` function, and it now runs successfully, but there is an error in `util.py`.
The main problem is in the function `combine_img_prediction`.
I changed it to print some intermediate shapes; the terminal shows:
to_rgb(crop_to_shape(data, pred.shape).reshape(-1, ny, ch) (3840, 960, 4)
to_rgb(crop_to_shape(gt, pred.shape).reshape(-1, ny, 1)) (3840, 960, 3)
/home/tf_unet/tf_unet/util.py:74: RuntimeWarning: invalid value encountered in true_divide img /= np.amax(img)
to_rgb(pred.reshape(-1, ny, 1))) (3840, 960, 3)
Traceback (most recent call last):
File "myTrain.py", line 13, in
def combine_img_prediction(data, gt, pred):
    """
    Combines the data, ground truth and the prediction into one rgb image.

    Multi-band data (e.g. r, g, b, nir) is reduced to its first three bands so
    that all three panels have the same number of color channels and can be
    concatenated side by side.

    :param data: the data tensor, indexed as (batch, y, x, channels)
    :param gt: the ground truth tensor (last axis = classes)
    :param pred: the prediction tensor (last axis = classes)
    :returns img: the concatenated rgb image
    """
    ny = pred.shape[2]
    ch = data.shape[3]

    data = crop_to_shape(data, pred.shape)
    if ch > 3:
        # to_rgb passes extra bands through (a 4-band input came back with 4
        # channels) while the 1-channel label panels come back with 3, which
        # makes the concatenation fail. Keep the first three bands, presumably
        # r, g, b — adjust the slice if the band order differs.
        data = data[..., :3]
        ch = 3

    # Class channel to visualize: the only one for a single-class setup,
    # channel 1 (the foreground in the two-class setup) otherwise. Slicing with
    # a range keeps the trailing axis so the single-class path is unchanged.
    gt_cls = 1 if gt.shape[-1] > 1 else 0
    pred_cls = 1 if pred.shape[-1] > 1 else 0

    img = np.concatenate(
        (to_rgb(data.reshape(-1, ny, ch)),
         to_rgb(crop_to_shape(gt[..., gt_cls:gt_cls + 1], pred.shape).reshape(-1, ny, 1)),
         to_rgb(pred[..., pred_cls:pred_cls + 1].reshape(-1, ny, 1))),
        axis=1)
    return img
My training data have 4 channels (R, G, B, NIR), but it seems the function cannot handle that!