transformers
Reverting Deta cloning mechanism.
What does this PR do?
This one is quite odd. With the revert, the slow test passes (which, I guess, is what we care about most):
from transformers import AutoImageProcessor, DetaForObjectDetection
from PIL import Image
import requests
import torch
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
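image_processor = AutoImageProcessor.from_pretrained("jozhang97/deta-swin-large")
model = DetaForObjectDetection.from_pretrained("jozhang97/deta-swin-large")
inputs = image_processor(images=image, return_tensors="pt")
# Inference only, so gradients are not needed
with torch.no_grad():
    outputs = model(**inputs)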
target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
print(results)
However, if I incorporate this save/reload round trip:
model = DetaForObjectDetection.from_pretrained("jozhang97/deta-swin-large")
model.save_pretrained("./tmp")
model = DetaForObjectDetection.from_pretrained("./tmp")
~Then, the output is garbage again (this isn't using safetensors and is not linked to the original change). I even tried to revert the PR that introduced the bug.~
The change in output is due to safetensors. I need to check this thoroughly.
This revert will fix the slow test anyway.
I think something is not set up properly in this model: the uploaded model seems to have those layers NOT tied (hence the copy.deepcopy), while the rest of the configuration seems to assume they are, which may be what causes the issue.
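A minimal sketch of the difference between shared and deep-copied layers, assuming a toy nn.Linear head (the sizes and names are made up; this is not DETA's actual code): repeating the same module object in an nn.ModuleList ties every entry, whereas calling copy.deepcopy per entry gives each layer its own independent weights.

```python
import copy

import torch
from torch import nn

num_layers = 3
class_embed = nn.Linear(4, 2)  # toy stand-in for a prediction head

# Same module object repeated: every entry shares one set of weights (tied).
tied = nn.ModuleList([class_embed for _ in range(num_layers)])

# copy.deepcopy per entry: each layer gets independent weights (not tied).
untied = nn.ModuleList([copy.deepcopy(class_embed) for _ in range(num_layers)])

print(tied[0].weight is tied[1].weight)      # True
print(untied[0].weight is untied[1].weight)  # False

# An in-place update to entry 0 is visible in every entry only when tied.
with torch.no_grad():
    tied[0].bias.fill_(1.0)
    untied[0].bias.fill_(1.0)
print(torch.equal(tied[2].bias, tied[0].bias))      # True
print(torch.equal(untied[2].bias, untied[0].bias))  # False
```

If a checkpoint was trained with the deep-copied (untied) layout but the loading code assumes the tied one, only one set of values can survive, which fits the suspicion above.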
Fixes https://github.com/huggingface/transformers/pull/22437#issuecomment-1500356727
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
There is, however, test_can_use_safetensors failing after this PR. Is this test still relevant (at least while we keep the changes in this PR)?
The new code should fix everything.
@sgugger, could you do a new review? The change has evolved quite a bit and is no longer a simple revert. I added inline comments in the PR to explain what's going on.
So we tried it your way and it doesn't work. Can we try to use Accelerate to detect the tied weights instead as suggested initially?
Because find_tied_weights looks at the model, whereas here we look at the state_dict, which can be passed directly to the function. In both functions, the state_dict is the source of truth, not the model, isn't it?
We could definitely use find_tied_weights, and it would most likely pass the tests, but it wouldn't be looking at exactly the same thing. The state dict is what is coming in, while find_tied_weights looks at where it is being loaded into (in from_pretrained; it's the opposite in save_pretrained). In general they should be the same, but not necessarily always.
For instance, I wonder what happens for buffers.
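One relevant fact on the buffer question: buffers appear in state_dict() but not in named_parameters(), so any detector that only walks parameters cannot see a buffer shared between two submodules. A small sketch with a made-up SharedBufferModel:

```python
import torch
from torch import nn


class SharedBufferModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Module()
        self.b = nn.Module()
        table = torch.arange(4.0)
        # Register the same tensor as a (persistent) buffer on two submodules.
        self.a.register_buffer("table", table)
        self.b.register_buffer("table", table)


model = SharedBufferModel()
param_names = [name for name, _ in model.named_parameters()]
state_keys = list(model.state_dict().keys())
sd = model.state_dict()

print(param_names)  # [] : buffers are not parameters
print(state_keys)   # ['a.table', 'b.table']
# Both state_dict entries still point at the same storage.
print(sd["a.table"].untyped_storage().data_ptr() == sd["b.table"].untyped_storage().data_ptr())  # True
```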
This will ignore the whole state dict as soon as device_map="auto" or low_cpu_mem_usage=True.
Why? It seems you're using the hash (via is) in accelerate; I will switch to that, since we want entirely shared tensors like in accelerate.
So actually, hash doesn't seem to work either: you can have a shared buffer and still different hashes.
I'll try to come up with a simple example, but in Deta, model_decoder.class_embed.n.bias and class_embed.n.bias do share the buffer, and yet don't have the same hash.
This exhibits the difference between find_tied_weights and the state_dict. Here the tensors from the state_dict don't have the same hash, while the parameters on the model do, yet the tensors in the state dict do share memory. In this particular case, using find_tied_weights would work, but that also means the opposite is possible.
In both situations, you have access to the model, and find_tied_weights will give you a list of names that are compatible with the state_dict of the model.
In this particular case, using find_tied_weights would work, but that also means the opposite is possible.
If this situation (the opposite) does not appear in Transformers, let's just use find_tied_weights.
I would also like to drive home the point that safetensors not dealing with shared weights makes it unusable in practice in other libraries: see what we have to do here... and we really want to use safetensors. How are we going to convince other users?
makes it unusable in practice
Why do we even care about _keys_to_ignore and tie_weights if it's so inconvenient?
Why are we even trying to find tied weights in accelerate?
How do we expect to use safetensors for the TF models, since sharing doesn't exist over there?
To make safetensors easier to use on its own, I created this PR:
https://github.com/huggingface/safetensors/pull/236
which sort of mimics what is done here.
However, I still think this PR and the mechanism in transformers should be kept, since _keys_to_ignore is very good at hinting which keys we should keep and which to drop, information that is not available in safetensors directly.
Also, the modifications are shallower here since they don't touch state_dict and load_state_dict, which the proposed methods do have to change.
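The idea that ignore hints tell us which duplicate to drop can be sketched as follows. drop_shared_duplicates is a hypothetical helper, not the code from this PR or from the safetensors PR: it groups state_dict keys by underlying storage and, within each group, prefers to drop keys matching a list of regex patterns (a stand-in for something like _keys_to_ignore), keeping a single key per storage. The key names and pattern are made up. It treats any two tensors on the same storage as duplicates, which is only safe for fully tied tensors, not for partial views of a larger tensor.

```python
import re
from collections import defaultdict

import torch


def drop_shared_duplicates(state_dict, ignore_patterns):
    """Hypothetical: keep one key per shared storage, preferring non-ignored keys."""
    groups = defaultdict(list)
    for key, tensor in state_dict.items():
        groups[(tensor.device, tensor.untyped_storage().data_ptr())].append(key)

    to_drop = set()
    for keys in groups.values():
        if len(keys) < 2:
            continue
        # Keys matching an ignore pattern are the preferred ones to drop.
        ignorable = [k for k in keys if any(re.search(p, k) for p in ignore_patterns)]
        kept = [k for k in keys if k not in ignorable]
        # If every key is ignorable, still keep one so the data is not lost.
        keep = kept[0] if kept else keys[0]
        to_drop.update(k for k in keys if k != keep)
    return {k: v for k, v in state_dict.items() if k not in to_drop}


embed = torch.randn(10, 4)
state_dict = {
    "embed.weight": embed,
    "lm_head.weight": embed,  # tied to the embedding
    "lm_head.bias": torch.zeros(10),
}
cleaned = drop_shared_duplicates(state_dict, ignore_patterns=[r"lm_head\.weight"])
print(sorted(cleaned))  # ['embed.weight', 'lm_head.bias']
```

Without the hint, the helper could only pick an arbitrary survivor (here, the first key seen), which is exactly the information safetensors lacks on its own.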
Thanks for considering shared weights in safetensors directly. I agree it would still be cleaner to have the same kind of mechanism in Transformers. Could you please explain to me once again why the hash check does not work for the first changes in the PR (dropping weights in the checkpoint before passing it to safetensors)? I don't think we ever tie weights in Transformers other than by just setting the same tensors.
Mostly this: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2146
state_dict = kwargs.pop("state_dict", None)
Users can pass a state_dict that is not linked to self, so this PR tried to look only at the state_dict instead of self.
This is indeed a bit of an edge case.
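To make the edge case concrete, a minimal sketch with a toy tied model. tied_names_from_model and shared_keys_in_state_dict are hypothetical helpers written for illustration: the first groups parameter names by object identity on the model (in the spirit of find_tied_weights, but not accelerate's implementation), the second groups state_dict keys by storage pointer. The model's own state_dict preserves the tie, but a state_dict passed in by the user, built here from clones, carries no sharing at all.

```python
from collections import defaultdict

import torch
from torch import nn


class TiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(8, 8)
        self.b = nn.Linear(8, 8)
        self.b.weight = self.a.weight  # tie the weights


def tied_names_from_model(model):
    """Hypothetical helper: group parameter names by Python object identity."""
    groups = defaultdict(list)
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            full_name = f"{module_name}.{param_name}" if module_name else param_name
            groups[id(param)].append(full_name)
    return sorted(sorted(names) for names in groups.values() if len(names) > 1)


def shared_keys_in_state_dict(state_dict):
    """Hypothetical helper: group state_dict keys by underlying storage."""
    groups = defaultdict(list)
    for key, tensor in state_dict.items():
        groups[(tensor.device, tensor.untyped_storage().data_ptr())].append(key)
    return sorted(sorted(keys) for keys in groups.values() if len(keys) > 1)


model = TiedModel()
# The model and its own state_dict agree on the tie...
print(tied_names_from_model(model))                   # [['a.weight', 'b.weight']]
print(shared_keys_in_state_dict(model.state_dict()))  # [['a.weight', 'b.weight']]

# ...but a user-supplied state_dict (here: cloned tensors) has no sharing.
user_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
print(shared_keys_in_state_dict(user_state_dict))     # []
```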
Then there are even further edge cases:
import torch
from torch import nn

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(100, 100)
        self.b = self.a  # same module registered under two names

model = Model()
assert model.a is model.b  # OK!
A = torch.zeros((1000, 100))
a = A[:100]
model.a.weight = nn.Parameter(a)
model.b.weight = model.a.weight
assert model.a.weight is model.b.weight  # Well indeed it's the same parameter, but both are shared with respect to a larger tensor

class NoSharedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(100, 100)
        self.b = torch.nn.Linear(100, 100)

model = NoSharedModel()
A = torch.zeros((100, 100))
model.a.weight = nn.Parameter(A)
model.b.weight = nn.Parameter(A[:10])  # a (10, 100) view into A; only used to show the sharing
assert model.a.weight is not model.b.weight  # A is not B in parameters, however, the underlying tensors are indeed shared
I haven't looked deeply at what happens when fine-tuning occurs, to see whether autograd starts to copy the tensors.
During state_dict(), a and b will be given back as shared tensors, yet the params don't have the same hash.
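The claim above can be checked with a small sketch that mirrors the NoSharedModel setup (rebuilt with an nn.ModuleDict so the snippet stands on its own; the names are illustrative): the two parameters are distinct objects with distinct hashes, yet their state_dict entries sit on the same storage.

```python
import torch
from torch import nn

a = nn.Linear(100, 100)
b = nn.Linear(100, 100)
base = torch.zeros((100, 100))
a.weight = nn.Parameter(base)
b.weight = nn.Parameter(base[:10])  # a (10, 100) view into the same memory

model = nn.ModuleDict({"a": a, "b": b})
sd = model.state_dict()

same_object = model["a"].weight is model["b"].weight
same_hash = hash(model["a"].weight) == hash(model["b"].weight)
same_storage = sd["a.weight"].untyped_storage().data_ptr() == sd["b.weight"].untyped_storage().data_ptr()
print(same_object, same_hash, same_storage)  # False False True
```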
If you want, I could take a look at accelerate's shared-params function and see if this applies. There are a lot of weird things when digging this deep; I discovered a lot of unexpected behavior with Deta from this PR.
But the biggest reason, really, is the optional state_dict, whereas accelerate looks directly at the model. Within from_pretrained, looking at the model is better in this case, since what matters is the user's model rather than the state_dict coming from the file (be it PyTorch or safetensors).
Apart from that, only a rebase on main should be needed here.
Note that I will rework the constants in future work to have one distinct key for the tied weights (as sometimes they are not tied and we are currently not warning the user if they are missing), but it's orthogonal to this PR.
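A hypothetical sketch of why a distinct key for tied weights helps (keys_to_warn_about, its parameters, and the key names are made up for illustration; this is not the Transformers API): a tied key may legitimately be missing from a checkpoint only when the weights are actually tied, so it should still trigger a warning otherwise.

```python
import re


def keys_to_warn_about(missing_keys, tied_weights_keys, weights_are_tied):
    """Hypothetical helper: which missing keys deserve a warning."""
    warn = []
    for key in missing_keys:
        is_tied_key = any(re.search(p, key) for p in tied_weights_keys)
        # A tied key may be absent from the checkpoint only when the weights
        # really are tied; otherwise it is a genuinely missing weight.
        if is_tied_key and weights_are_tied:
            continue
        warn.append(key)
    return warn


missing = ["lm_head.weight", "encoder.layer.0.bias"]
print(keys_to_warn_about(missing, [r"lm_head\.weight"], weights_are_tied=True))
# ['encoder.layer.0.bias']
print(keys_to_warn_about(missing, [r"lm_head\.weight"], weights_are_tied=False))
# ['lm_head.weight', 'encoder.layer.0.bias']
```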
Great!
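Looking at the rebase, hash doesn't work on tensors, unfortunately:
import torch
A = torch.zeros((10, 10))
B = A[1]
# B is a view into A, so both point at the same underlying storage...
assert A.untyped_storage().data_ptr() == B.untyped_storage().data_ptr()
# ...yet Tensor.__hash__ is identity-based, so their hashes differ
assert hash(A) != hash(B)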
(which will ultimately become the default)
Hurray!!!
Failing tests seem to be linked to the newly released huggingface_hub==0.14.0.
@sgugger Merge if you think it's OK; I won't merge it myself, given that this PR affects core modeling.