optimum
Add BetterTransformer support for the FLAVA model
What does this PR do?
Fixes #20372
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [X] Did you make sure to update the documentation with your changes?
- [X] Did you write any new necessary tests?
This PR replaces an earlier, unfinished PR that contained outdated code and irrelevant file edits. Output of the FLAVA test script:
```
FlavaModel(
  (text_model): FlavaTextModel(
    (embeddings): FlavaTextEmbeddings(
      (word_embeddings): Embedding(1124, 32, padding_idx=0)
      (position_embeddings): Embedding(512, 32)
      (token_type_embeddings): Embedding(2, 32)
      (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): FlavaEncoder(
      (layer): ModuleList(
        (0): FlavaLayerBetterTransformer()
        (1): FlavaLayerBetterTransformer()
        (2): FlavaLayerBetterTransformer()
        (3): FlavaLayerBetterTransformer()
        (4): FlavaLayerBetterTransformer()
      )
    )
    (layernorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
    (pooler): FlavaPooler(
      (dense): Linear(in_features=32, out_features=32, bias=True)
      (activation): Tanh()
    )
  )
  (image_model): FlavaImageModel(
    (embeddings): FlavaImageEmbeddings(
      (patch_embeddings): PatchEmbeddings(
        (projection): Conv2d(3, 32, kernel_size=(2, 2), stride=(2, 2))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): FlavaEncoder(
      (layer): ModuleList(
        (0): FlavaLayerBetterTransformer()
        (1): FlavaLayerBetterTransformer()
        (2): FlavaLayerBetterTransformer()
        (3): FlavaLayerBetterTransformer()
        (4): FlavaLayerBetterTransformer()
      )
    )
    (layernorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
    (pooler): FlavaPooler(
      (dense): Linear(in_features=32, out_features=32, bias=True)
      (activation): Tanh()
    )
  )
  (multimodal_model): FlavaMultimodalModel(
    (encoder): FlavaEncoder(
      (layer): ModuleList(
        (0): FlavaLayerBetterTransformer()
        (1): FlavaLayerBetterTransformer()
        (2): FlavaLayerBetterTransformer()
        (3): FlavaLayerBetterTransformer()
        (4): FlavaLayerBetterTransformer()
      )
    )
    (layernorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
    (pooler): FlavaPooler(
      (dense): Linear(in_features=32, out_features=32, bias=True)
      (activation): Tanh()
    )
  )
  (image_projection): Linear(in_features=32, out_features=32, bias=True)
  (text_projection): Linear(in_features=32, out_features=32, bias=True)
  (image_to_mm_projection): Linear(in_features=32, out_features=32, bias=True)
  (text_to_mm_projection): Linear(in_features=32, out_features=32, bias=True)
)
```
cc @younesbelkada @michaelbenayoun. Please let me know if anything needs correcting.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Hi @katiele47, would you like to have this merged? If so, let's resolve the conflicts and I can trigger the CI.
This PR has been marked as stale because it has been open for 90 days with no activity. This thread will be automatically closed in 30 days if no further activity occurs.