[Vision] Support different floating precision inputs from the `ImageProcessor`

Open younesbelkada opened this issue 2 years ago • 7 comments

What does this PR do?

This PR introduces the input casting mechanism for image processors. Since the introduction of accelerate-supported models for Vision, I have been playing around with half-precision models. I found it a bit unintuitive to manually cast the pixel_values outside the ImageProcessor class. Therefore, for some models, small hacks have been introduced to make the casting operation more user-friendly. With this PR, it will be possible to cast the input tensors to any floating point precision, for any framework, at the ImageProcessor level as follows:

from transformers import ViTFeatureExtractor
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-large-patch32-384')
inputs = feature_extractor(images=image, return_tensors="np", float_precision="float16")
print(inputs.pixel_values.dtype)
>>> float16

The casting skips non-floating-point tensors, so these tensors should not be affected by the casting mechanism (this matters, for example, for ViLT, which takes both text and image inputs).
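To make the "only floating-point tensors are cast" behaviour concrete, here is a minimal NumPy sketch. The helper name `cast_floating_inputs` and the toy batch are hypothetical and exist only for illustration; this is not the code in the PR, just one way the rule described above could look.

```python
import numpy as np


def cast_floating_inputs(inputs, float_precision="float16"):
    """Return a new dict where only floating-point arrays are cast.

    Hypothetical helper for illustration; not the transformers implementation.
    """
    target = np.dtype(float_precision)
    if not np.issubdtype(target, np.floating):
        raise ValueError(f"{float_precision!r} is not a floating-point dtype")

    casted = {}
    for name, value in inputs.items():
        array = np.asarray(value)
        if np.issubdtype(array.dtype, np.floating):
            # Floating-point inputs (e.g. pixel_values) get the requested precision.
            casted[name] = array.astype(target)
        else:
            # Integer inputs (e.g. input_ids for a text + image model) are left as-is.
            casted[name] = array
    return casted


# Toy batch mimicking a text + image model: float pixels plus integer token ids.
batch = {
    "pixel_values": np.random.rand(1, 3, 4, 4).astype(np.float32),
    "input_ids": np.array([[101, 2023, 102]], dtype=np.int64),
}
half = cast_floating_inputs(batch, "float16")
print(half["pixel_values"].dtype)  # float16
print(half["input_ids"].dtype)  # int64
```

Checking the dtype per tensor, rather than casting everything, is what keeps token ids and similar integer inputs intact when text and image inputs share one batch.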

With this PR, the hacks introduced on ViT and OWLViT will be removed!

cc @amyeroberts @ydshieh

younesbelkada avatar Nov 25 '22 16:11 younesbelkada

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

I think this PR is ready, at least as a PoC! To make the PR complete, for now the arg float_precision needs to be manually added for each image processor. Before moving forward, doing it for all image processors, and adding tests, I would love to hear from @sgugger, @amyeroberts & @ydshieh to see if this is the approach we would like to follow! Thanks again!

younesbelkada avatar Nov 29 '22 10:11 younesbelkada

Thanks so much everyone for your comments! After thinking a bit and trying to see if this could be useful for Flax:

import jax.numpy as jnp
from transformers import FlaxViTForImageClassification, ViTFeatureExtractor

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224", dtype=jnp.float16)
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

inputs = feature_extractor(images=image, return_tensors="np")
outputs = model(**inputs)
print(outputs)

it seems that Flax can deal properly with different dtypes, without having to explicitly cast the input. I think that a good point has been raised by @sgugger; however, this could still be useful if it is needed on the TF side. If not, I'm happy to change the PR to something that modifies only the .to function, as this would then be intended only for PyTorch.

younesbelkada avatar Dec 01 '22 09:12 younesbelkada

I don't have a strong opinion though. So you can follow what @sgugger suggests. If we find it's useful for other frameworks, we can add them back.

ydshieh avatar Dec 01 '22 11:12 ydshieh

Thanks everyone! Let's keep this PR open in case we figure out this is needed for tf. I have opened a PR in #20536 for supporting dtypes in .to

younesbelkada avatar Dec 02 '22 10:12 younesbelkada

@gante @Rocketknight1 - how useful would this be in TF land?

I don't think our TF models are compatible with half-precision, right @Rocketknight1? At least I haven't used TF with half-precision :D

gante avatar Dec 16 '22 17:12 gante

Extremely late reply on the TF front, but yeah, we aren't really running TF models in half precision right now. We do support mixed precision (similar to Torch AMP), but we don't officially support splatting the whole model to (b)float16 yet.

Rocketknight1 avatar Jan 19 '23 14:01 Rocketknight1

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Apr 03 '23 15:04 github-actions[bot]