imax
[feature req.] cropping utilities
The tensorflow.image package includes some nice cropping utilities - e.g. central cropping, random cropping, etc.
Any plans to do something similar for this library? How difficult would it be to implement?
There's a simple random crop implementation here which has been handy for me. It uses jax.lax.dynamic_slice:
import jax


def random_crop(key, image, crop_sizes):
    """Crop images randomly to specified sizes.

    Given an input image, it crops the image to the specified `crop_sizes`. If
    `crop_sizes` are smaller than the image's sizes, the offset for cropping is
    chosen at random. To deterministically crop an image,
    please use `jax.lax.dynamic_slice` and specify offsets and crop sizes.

    Args:
      key: Key for pseudo-random number generator.
      image: A JAX array which represents an image.
      crop_sizes: A sequence of integers, each of which sequentially specifies the
        crop size along the corresponding dimension of the image. Sequence length
        must be identical to the rank of the image and the crop size should not be
        greater than the corresponding image dimension.

    Returns:
      A cropped image, a JAX array whose shape is the same as `crop_sizes`.
    """
    image_shape = image.shape
    assert len(image_shape) == len(crop_sizes), (
        f"Number of image dims {len(image_shape)} and number of crop_sizes "
        f"{len(crop_sizes)} do not match."
    )
    # Check each dimension separately: `image_shape >= crop_sizes` would compare
    # the sequences lexicographically, and fails for a tuple vs. a list.
    assert all(i >= c for i, c in zip(image_shape, crop_sizes)), (
        f"Crop sizes {crop_sizes} should be a subset of image size "
        f"{image_shape} in each dimension."
    )
    # One independent key per dimension, so each offset is drawn separately.
    random_keys = jax.random.split(key, len(crop_sizes))
    slice_starts = [
        jax.random.randint(k, (), 0, img_size - crop_size + 1)
        for k, img_size, crop_size in zip(random_keys, image_shape, crop_sizes)
    ]
    out = jax.lax.dynamic_slice(image, slice_starts, crop_sizes)
    return out
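
For the central cropping mentioned in the original request, the same `jax.lax.dynamic_slice` call could probably be reused with fixed offsets instead of random ones. The following is only a rough sketch under that assumption: `central_crop` is a hypothetical name, not something that exists in imax or in the snippet above, and it centres the crop window by integer division.

```python
import jax
import jax.numpy as jnp


def central_crop(image, crop_sizes):
    """Hypothetical helper: crop the centre of `image` to `crop_sizes`."""
    image_shape = image.shape
    if len(image_shape) != len(crop_sizes):
        raise ValueError(
            f"Rank mismatch: image has {len(image_shape)} dims, "
            f"crop_sizes has {len(crop_sizes)}."
        )
    if any(c > s for s, c in zip(image_shape, crop_sizes)):
        raise ValueError(
            f"Crop sizes {crop_sizes} exceed image shape {image_shape}."
        )
    # Offsets are plain Python ints computed from static shapes, so the
    # crop window is centred (rounded down) along every dimension.
    slice_starts = [(s - c) // 2 for s, c in zip(image_shape, crop_sizes)]
    return jax.lax.dynamic_slice(image, slice_starts, tuple(crop_sizes))


# Usage: crop a 6x6 RGB image to its central 4x4 window.
image = jnp.arange(6 * 6 * 3).reshape(6, 6, 3)
cropped = central_crop(image, (4, 4, 3))
```

Because the offsets depend only on shapes, this version needs no PRNG key, and it should behave the same under `jax.jit` as long as `crop_sizes` is passed as a static argument.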