[feature] Augmentation function over a collection of images stored in a Numpy array.
🚀 Feature
I would like to apply a single augmentation function over a collection of images stored in a Numpy array. I tried passing a (K, L, L) array of K black-and-white images of size LxL to aug_np_wrapper, but it doesn't seem to work.
Motivation
I'm working on a project that requires transforming all of the images in a dataset, e.g., Fashion MNIST.
Pitch
I would like to be able to apply a transformation over a given axis for a collection of images.
Alternatives
The only alternative I can think of at the moment is a for-loop, but this would be quite inefficient for what I would like to achieve.
Hi @gerdm, thank you for this suggestion! Indeed our aug_np_wrapper (and all of our image augmentation functions in general) only expect single images as input for now. We could add support to aug_np_wrapper for taking in a batch of images. However, please note that our image augmentations are not implemented in numpy under the hood (aug_np_wrapper is just a wrapper which converts the numpy array to a PIL image and then calls the augmentation). Thus we will have to either use a for loop, which as you said is not very efficient, or we can get some speed-ups e.g. by using multiprocessing.
Let me know what you think, or feel free to try multiprocessing on your side to see if this unblocks you. I will add this to our backlog of tasks and will link the PR here when I get to it :)
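A minimal sketch of the multiprocessing pattern suggested above: the batch is split into chunks and each worker process loops over its chunk. To keep the sketch self-contained it uses a stand-in numpy transform (`invert`) rather than an augly augmentation; the function names `invert`, `process_chunk`, and `map_over_batch` are illustrative, not part of augly's API.

```python
import numpy as np
from multiprocessing import Pool


def invert(img):
    # Stand-in for an augly augmentation: invert pixel intensities.
    return 255 - img


def process_chunk(chunk):
    # Apply the transform to every image in one chunk of the batch.
    return np.stack([invert(img) for img in chunk], axis=0)


def map_over_batch(batch, n_processes=2):
    # Split the (K, L, L) array into chunks and process them in parallel.
    chunks = np.array_split(batch, n_processes)
    with Pool(processes=n_processes) as pool:
        out = pool.map(process_chunk, chunks)
    return np.concatenate(out, axis=0)


if __name__ == "__main__":
    batch = np.zeros((6, 8, 8), dtype=np.uint8)
    out = map_over_batch(batch, n_processes=2)
    print(out.shape)  # (6, 8, 8)
```

Swapping `invert` for a call to `image.aug_np_wrapper` gives a batched version of any augly augmentation, at the cost of one PIL round-trip per image.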
Hi @zpapakipos!
I went with the multiprocessing option you outlined. I'll paste the code here in case someone wants to try this in the future. The code below makes use of image.blur, but it's pretty easy to generalise to other methods.
First, we create a class that defines a callable to pass to Python multiprocessing:
import numpy as np
from multiprocessing import Pool
from augly import image
class BlurRad:
    def __init__(self, rad):
        self.rad = rad

    def __call__(self, img_batch):
        return self.blur_batch(img_batch)

    def blur(self, X):
        """
        Blur an image using the augly library.

        Parameters
        ----------
        X: np.array
            A single NxM-dimensional array. The amount of
            blurriness is given by self.rad.
        """
        return image.aug_np_wrapper(X, image.blur, radius=self.rad)

    def blur_batch(self, X_batch):
        images_out = []
        for X in X_batch:
            img_blur = self.blur(X)
            images_out.append(img_blur)
        images_out = np.stack(images_out, axis=0)
        return images_out
We can then use Python multiprocessing to blur a collection of images with a single radius.
def proc_dataset(img_dataset, radius, n_processes):
    """
    Blur all images of a dataset stored in a numpy array.

    Parameters
    ----------
    img_dataset: array(N, L, K)
        N images of size LxK
    radius: float
        Intensity of blurriness
    n_processes: int
        Number of processes to blur over
    """
    with Pool(processes=n_processes) as pool:
        dataset_proc = np.array_split(img_dataset, n_processes)
        dataset_proc = pool.map(BlurRad(radius), dataset_proc)
        dataset_proc = np.concatenate(dataset_proc, axis=0)
    n_obs = len(img_dataset)
    dataset_proc = dataset_proc.reshape(n_obs, -1)
    return dataset_proc
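The correctness of proc_dataset rests on np.array_split and np.concatenate round-tripping the batch in order, even when n_processes does not divide N evenly. A quick numpy-only check of that behavior (the array here is synthetic, standing in for a real image dataset):

```python
import numpy as np

# 7 "images" of size 3x3, split across 3 workers.
data = np.arange(7 * 3 * 3).reshape(7, 3, 3)
chunks = np.array_split(data, 3)

# Unlike np.split, array_split tolerates uneven division:
# chunk sizes here are 3, 2, 2.
sizes = [len(c) for c in chunks]

# Concatenating the chunks restores the original order, which is
# what lets the pool.map results be reassembled into one array.
restored = np.concatenate(chunks, axis=0)
```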
If we want to blur different images with different radii, we define the following function:
def blur_multiple(radii, img_dataset):
    """
    Blur every element of `img_dataset` with the corresponding
    element of `radii`.
    """
    imgs_out = []
    for radius, img in zip(radii, img_dataset):
        img_proc = BlurRad(radius).blur(img)
        imgs_out.append(img_proc)
    imgs_out = np.stack(imgs_out, axis=0)
    return imgs_out
def proc_dataset_multiple(img_dataset, radii, n_processes):
    """
    Blur all images of a dataset stored in a numpy array with a
    variable radius.

    Parameters
    ----------
    img_dataset: array(N, L, K)
        N images of size LxK
    radii: array(N,) or float
        Intensity of blurriness, one per image. If a
        float, the same value is used for all images.
    n_processes: int
        Number of processes to blur over
    """
    if isinstance(radii, (float, np.floating)):
        radii = radii * np.ones(len(img_dataset))
    with Pool(processes=n_processes) as pool:
        dataset_split = np.array_split(img_dataset, n_processes)
        radii_split = np.array_split(radii, n_processes)
        elements = zip(radii_split, dataset_split)
        dataset_proc = pool.starmap(blur_multiple, elements)
        dataset_proc = np.concatenate(dataset_proc, axis=0)
    return dataset_proc
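The same pool.starmap pattern can be checked end to end without augly by swapping the blur for a stand-in per-image transform; here `scale_chunk` and `proc_scaled` are illustrative names that multiply each image by its own factor, mirroring how blur_multiple pairs each image with its radius:

```python
import numpy as np
from multiprocessing import Pool


def scale_chunk(factors, imgs):
    # One parameter per image, mirroring blur_multiple's zip
    # over (radii, img_dataset).
    return np.stack([f * img for f, img in zip(factors, imgs)], axis=0)


def proc_scaled(img_dataset, factors, n_processes):
    # Split the images and their per-image parameters identically,
    # so each (factors, imgs) pair stays aligned across workers.
    with Pool(processes=n_processes) as pool:
        img_chunks = np.array_split(img_dataset, n_processes)
        factor_chunks = np.array_split(factors, n_processes)
        out = pool.starmap(scale_chunk, zip(factor_chunks, img_chunks))
    return np.concatenate(out, axis=0)


if __name__ == "__main__":
    data = np.ones((4, 2, 2))
    factors = np.array([1.0, 2.0, 3.0, 4.0])
    print(proc_scaled(data, factors, 2)[3])  # each entry is 4.0
```

Because both arrays are split with the same np.array_split call, the i-th parameter always travels with the i-th image, which is the invariant proc_dataset_multiple depends on.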