
Update feature extractor methods to type cast before normalize

Open • amyeroberts opened this issue 3 years ago • 1 comment

What does this PR do?

At the moment, our feature extractors don't always return the expected type, and calls sometimes fail, when a do_xxx config flag is set to False. This PR makes the changes to the ImageFeatureExtractionMixin methods that are needed so the feature extractor calls can be modified to fix this. It is an alternative to making return_tensors="np" the default.

Each vision model using ImageFeatureExtractionMixin has a separate PR adding its necessary modifications and tests.

Details

At the moment, if do_normalize=False, do_resize=True and return_tensors=None, the output will be a list of PIL.Image.Image objects even if the inputs are numpy arrays. If do_normalize=False and return_tensors is specified ("pt", "np", "tf" or "jax"), an exception is raised.

The main reasons for this are:

  • BatchFeature can't convert PIL.Image.Image to the requested tensors.
  • The necessary conversion of PIL.Image.Image -> np.ndarray happens within the normalize method and the output of resize is PIL.Image.Image.
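A minimal sketch of the control flow described above, assuming a simplified feature extractor call with hypothetical helper names (`old_call`, `resize`, `normalize`) and a hypothetical `FakePILImage` class standing in for `PIL.Image.Image`, so it runs with NumPy alone. It shows that when `do_normalize=False` the output keeps whatever type `resize` returned.

```python
import numpy as np


class FakePILImage:
    """Hypothetical stand-in for PIL.Image.Image: wraps a uint8 HWC array."""

    def __init__(self, array):
        self.array = np.asarray(array, dtype=np.uint8)


def resize(image, size):
    # Like the mixin's resize, this returns an image object rather than an array.
    # Nearest-neighbour sampling keeps the sketch dependency-free.
    array = image.array if isinstance(image, FakePILImage) else np.asarray(image)
    rows = np.arange(size) * array.shape[0] // size
    cols = np.arange(size) * array.shape[1] // size
    return FakePILImage(array[rows][:, cols])


def normalize(image, mean, std):
    # The image -> ndarray conversion (with rescaling to [0, 1]) only happens here.
    if isinstance(image, FakePILImage):
        image = image.array.astype(np.float32) / 255.0
    return (image - mean) / std


def old_call(images, do_resize=True, do_normalize=True, size=4):
    # Hypothetical simplification of a feature extractor call before this PR.
    if do_resize:
        images = [resize(image, size) for image in images]
    if do_normalize:
        images = [normalize(image, 0.5, 0.5) for image in images]
    return images  # passed to BatchFeature unchanged


inputs = [np.arange(8 * 8 * 3, dtype=np.uint8).reshape(8, 8, 3)]
print(type(old_call(inputs)[0]).__name__)                      # ndarray
print(type(old_call(inputs, do_normalize=False)[0]).__name__)  # FakePILImage
```

With `do_resize=False`, `normalize` in this sketch receives the raw `uint8` array and skips the [0, 1] rescaling, so the scale of the output also depends on the flags.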

In order to have the type of the returned pixel_values reflect return_tensors we need to:

  • Convert PIL.Image.Image objects to numpy arrays before passing to BatchFeature
  • Be able to optionally rescale the inputs in the normalize method. If the input to normalize is a PIL.Image.Image, it is converted to a numpy array using to_numpy_array, which rescales the values to [0, 1]. If do_resize=False and the inputs are numpy arrays, they never go through this conversion, so they are not rescaled.
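A minimal sketch of these two changes, using hypothetical helper names (`to_array`, `normalize`, `new_call`) and a hypothetical `FakePILImage` stand-in for `PIL.Image.Image` so it runs with NumPy alone; resizing is omitted for brevity. It follows the description above rather than reproducing the transformers implementation.

```python
import numpy as np


class FakePILImage:
    """Hypothetical stand-in for PIL.Image.Image: wraps a uint8 HWC array."""

    def __init__(self, array):
        self.array = np.asarray(array, dtype=np.uint8)


def to_array(image, rescale=False):
    # Convert an image object or array to float32, optionally scaling to [0, 1].
    array = image.array if isinstance(image, FakePILImage) else np.asarray(image)
    array = array.astype(np.float32)
    return array / 255.0 if rescale else array


def normalize(image, mean, std, rescale=False):
    # Image objects are always rescaled on conversion (unchanged default);
    # arrays are rescaled only when the caller opts in with rescale=True.
    if rescale or isinstance(image, FakePILImage):
        image = to_array(image, rescale=True)
    return (image - mean) / std


def new_call(images, do_normalize=True):
    # Cast to arrays up front, so the output type no longer depends on do_normalize.
    images = [to_array(image) for image in images]
    if do_normalize:
        # The arrays are unscaled at this point, so ask normalize to rescale them.
        images = [normalize(image, 0.5, 0.5, rescale=True) for image in images]
    return images


pixels = np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)
print(np.allclose(new_call([FakePILImage(pixels)])[0], new_call([pixels])[0]))  # True
print(type(new_call([FakePILImage(pixels)], do_normalize=False)[0]).__name__)   # ndarray
```

Because `rescale` defaults to False in this sketch, an existing call such as `normalize(array, mean, std)` behaves exactly as before; only the updated call passes the new flag.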

The optional flags enable us to preserve the same default behaviour for the resize and normalize methods whilst modifying the internal logic of the feature extractor call.

Checks

The model PRs all consist of file diffs cherry-picked from the type-cast-before-normalize branch.

The following was run to check the outputs:

from dataclasses import dataclass

import requests
import numpy as np
from PIL import Image
import pygit2
from transformers import AutoFeatureExtractor

@dataclass
class FeatureExtractorConfig:
    model_name: str
    checkpoint: str
    return_type: str = "np"
    feat_name: str = "pixel_values"

IMAGE_FEATURE_EXTRACTOR_CONFIGS = [
    FeatureExtractorConfig(model_name="clip", checkpoint="openai/clip-vit-base-patch32"),
    FeatureExtractorConfig(model_name="convnext", checkpoint="facebook/convnext-tiny-224"),
    FeatureExtractorConfig(model_name="deit", checkpoint="facebook/deit-base-distilled-patch16-224"),
    FeatureExtractorConfig(model_name="detr", checkpoint="facebook/detr-resnet-50"),
    FeatureExtractorConfig(model_name="dpt", checkpoint="Intel/dpt-large"),
    FeatureExtractorConfig(model_name="flava", checkpoint="facebook/flava-full"),
    FeatureExtractorConfig(model_name="glpn", checkpoint="vinvino02/glpn-kitti"),
    FeatureExtractorConfig(model_name="imagegpt", checkpoint="openai/imagegpt-small", feat_name='input_ids'),
    FeatureExtractorConfig(model_name="layoutlmv2", checkpoint="microsoft/layoutlmv2-base-uncased"),
    FeatureExtractorConfig(model_name="layoutlmv3", checkpoint="microsoft/layoutlmv3-base"),
    FeatureExtractorConfig(model_name="levit", checkpoint="facebook/levit-128S"),
    FeatureExtractorConfig(model_name="maskformer", checkpoint="facebook/maskformer-swin-base-ade", return_type="pt"),
    FeatureExtractorConfig(model_name="mobilevit", checkpoint="apple/mobilevit-small"),
    FeatureExtractorConfig(model_name="owlvit", checkpoint="google/owlvit-base-patch32"),
    FeatureExtractorConfig(model_name="perceiver", checkpoint="deepmind/vision-perceiver-fourier"),
    FeatureExtractorConfig(model_name="poolformer", checkpoint="sail/poolformer_s12"),
    FeatureExtractorConfig(model_name="segformer", checkpoint="nvidia/mit-b0"),
    FeatureExtractorConfig(model_name="vilt", checkpoint="dandelin/vilt-b32-mlm"),
    FeatureExtractorConfig(model_name="vit", checkpoint="google/vit-base-patch16-224-in21k"),
    FeatureExtractorConfig(model_name="yolos", checkpoint="hustvl/yolos-small"),
]

VIDEO_FEATURE_EXTRACTOR_CONFIGS = [
    FeatureExtractorConfig(model_name="videomae", checkpoint="MCG-NJU/videomae-base"),
]

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

def produce_pixel_value_outputs():
    # Name of the currently checked-out branch, used to tag the saved outputs
    BRANCH = pygit2.Repository('.').head.shorthand

    def get_processed_outputs(inputs, model_checkpoint, feat_name, return_type):
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
        outputs = feature_extractor(inputs, return_tensors=return_type)[feat_name]
        return outputs

    for fe_config in IMAGE_FEATURE_EXTRACTOR_CONFIGS:
        print(fe_config.model_name, fe_config.checkpoint)
        outputs = get_processed_outputs(image, fe_config.checkpoint, fe_config.feat_name, fe_config.return_type)
        np.save(f"{fe_config.model_name}_{BRANCH.replace('-', '_')}_pixel_values.npy", outputs)

    # Video feature extractors take a batch of videos: here one video of two frames
    for fe_config in VIDEO_FEATURE_EXTRACTOR_CONFIGS:
        print(fe_config.model_name, fe_config.checkpoint)
        outputs = get_processed_outputs([[image, image]], fe_config.checkpoint, fe_config.feat_name, fe_config.return_type)
        np.save(f"{fe_config.model_name}_{BRANCH.replace('-', '_')}_pixel_values.npy", outputs)

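branch_main = "main"
branch_feature = "type-cast-before-normalize"

repo = pygit2.Repository('.git')

def checkout_branch(branch_name):
    print(f"\nChecking out {branch_name}")
    branch = repo.lookup_branch(branch_name)
    ref = repo.lookup_reference(branch.name)
    repo.checkout(ref)

# Caveat: Python caches imported modules in sys.modules, so transformers modules
# imported while processing the first branch are not reloaded after checking out
# the second one. Running produce_pixel_value_outputs() in a fresh interpreter
# for each branch avoids comparing a branch against itself.
checkout_branch(branch_main)
produce_pixel_value_outputs()

checkout_branch(branch_feature)
produce_pixel_value_outputs()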

for fe_config in IMAGE_FEATURE_EXTRACTOR_CONFIGS + VIDEO_FEATURE_EXTRACTOR_CONFIGS:
    model_name = fe_config.model_name

    try:
        output_1 = np.load(f"{model_name}_{branch_main}_pixel_values.npy")
        output_2 = np.load(f"{model_name}_{branch_feature.replace('-', '_')}_pixel_values.npy")

        max_diff = np.amax(np.abs(output_1 - output_2))
        print(f"{model_name}: {max_diff:.5f}")
    except Exception as e:
        print(f"{model_name} failed check with {e}")

Output:

clip: 0.00000
convnext: 0.00000
deit: 0.00000
detr: 0.00000
dpt: 0.00000
flava: 0.00000
glpn: 0.00000
imagegpt: 0.00000
layoutlmv2: 0.00000
layoutlmv3: 0.00000
levit: 0.00000
maskformer: 0.00000
mobilevit: 0.00000
owlvit: 0.00000
perceiver: 0.00000
poolformer: 0.00000
segformer: 0.00000
vilt: 0.00000
vit: 0.00000
yolos: 0.00000
videomae: 0.00000
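The comparison loop in the script reports only the maximum absolute difference. Because NumPy broadcasts, `np.abs(output_1 - output_2)` can succeed even when the shapes differ, for example (1, 3, 224, 224) against (3, 224, 224), and a dtype change goes unreported. A minimal sketch of a stricter check follows; `compare_outputs` is a hypothetical helper, not part of the original script, and the default tolerance is an arbitrary choice.

```python
import numpy as np


def compare_outputs(output_1, output_2, atol=1e-6):
    """Hypothetical helper: compare two feature-extractor outputs strictly."""
    output_1, output_2 = np.asarray(output_1), np.asarray(output_2)
    # Check shapes first: broadcasting would otherwise hide a mismatch.
    if output_1.shape != output_2.shape:
        return False, f"shape mismatch {output_1.shape} vs {output_2.shape}"
    # A change of dtype (e.g. float64 -> float32) is also worth flagging.
    if output_1.dtype != output_2.dtype:
        return False, f"dtype mismatch {output_1.dtype} vs {output_2.dtype}"
    if output_1.size == 0:
        return True, "both empty"
    # Compare in float64 so unsigned integer outputs can't wrap around on subtraction.
    diff = np.abs(output_1.astype(np.float64) - output_2.astype(np.float64))
    max_diff = float(diff.max())
    return max_diff <= atol, f"max diff {max_diff:.2e}"


a = np.zeros((1, 3, 4, 4), dtype=np.float32)
print(compare_outputs(a, a.copy()))  # (True, 'max diff 0.00e+00')
print(compare_outputs(a, a[0]))      # (False, 'shape mismatch (1, 3, 4, 4) vs (3, 4, 4)')
```

Inside the loop, `ok, message = compare_outputs(output_1, output_2)` could replace the two `max_diff` lines.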

Fixes

https://github.com/huggingface/transformers/issues/17714 https://github.com/huggingface/transformers/issues/15055

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [x] Did you write any new necessary tests? (in model PRs)

amyeroberts • Aug 05 '22 21:08

The documentation is not available anymore as the PR was closed or merged.

Looks good to me! If the changes per model are small enough, it would probably be best to change them all in the same PR, rather than doing individual ones.

@sgugger Yep, I completely agree. The changes altogether aren't that small, but they are almost exactly the same across models. Once this is merged in, I'll open a PR for the VideoMAE refactor (https://github.com/amyeroberts/transformers/pull/9/files), as this covers all the changes. Once that's approved, I'll merge the other models into the branch, ask for a re-review of the total PR and then merge everything together.

amyeroberts • Aug 11 '22 11:08