tuned-lens

Make lenses device-agnostic using map_location in torch.load()

Open • Neelectric opened this issue 1 year ago • 3 comments

Some .pt files on the Hugging Face Space (e.g. this one; it looks like all Pythia lenses might be affected, while gpt2-large seems to work) contain the device string "cuda:0". See the following screenshot (forgive me for opening a .pt file in vim; it won't happen again):

[Screenshot: a lens .pt file opened in vim, showing the embedded "cuda:0" device string]

This means that when th.load() (i.e. torch.load()) is called and unpickling occurs, it tries to restore the tensors onto "cuda:0" regardless of which PyTorch backends are available on the machine or which device the model itself was loaded onto. This is a shame, as Mac users relying on MPS won't be able to load these lenses.
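For context, here is a minimal sketch of how `map_location` lets `torch.load` override the device recorded in a checkpoint. It uses an in-memory buffer as a stand-in for a lens file, and `load_remapped` is a hypothetical helper used only for illustration. Because a CPU-only machine can't produce a checkpoint that really records "cuda:0", the dict form below only shows how that tag would be remapped; in this demo the storages were saved on CPU.

```python
import io

import torch

# A small state dict serialized to memory, standing in for a lens checkpoint.
state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
buffer = io.BytesIO()
torch.save(state, buffer)
raw = buffer.getvalue()


def load_remapped(raw_bytes: bytes, map_location):
    """Hypothetical helper: deserialize raw_bytes, remapping storages via map_location."""
    return torch.load(io.BytesIO(raw_bytes), map_location=map_location)


# 1. A device string: every storage is placed on that device.
on_cpu = load_remapped(raw, "cpu")

# 2. A dict of location tags: storages recorded as "cuda:0" would be moved to
#    "cpu"; storages recorded under any other tag keep their saved location.
remapped = load_remapped(raw, {"cuda:0": "cpu"})

print(on_cpu["weight"].device, remapped["bias"].device)
```

Without `map_location`, the same call on a checkpoint saved from "cuda:0" would try to restore onto CUDA and fail on a machine without it.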

The two-line change in this commit checks which device the already-loaded embedding is on, and then loads the lens state (and therefore the contents of the .pt file) onto that same device using map_location. I think this is quite an elegant solution: it gives users the flexibility to load their model onto the device of their choosing and have the lens follow it onto the same device. Please let me know if another solution would be preferable, and do forgive me if I have gone against any conventions of contributing to OSS. This is my first pull request ever; I mean well!
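A rough sketch of the approach described above, assuming the model's embedding module is already on the target device. `load_lens_state` is a hypothetical helper, not the actual tuned-lens code from this commit, and the checkpoint is an in-memory stand-in for a real lens file.

```python
import io

import torch
from torch import nn


def load_lens_state(source, embedding: nn.Module) -> dict:
    """Hypothetical helper: load a lens checkpoint onto the embedding's device.

    The device is read from the embedding's parameters and passed to
    torch.load via map_location, so storages recorded as "cuda:0" are
    restored onto CPU, MPS, or whatever device the model actually uses.
    """
    device = next(embedding.parameters()).device
    return torch.load(source, map_location=device)


# Put the model's embedding on whichever backend is available...
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

embedding = nn.Embedding(10, 8).to(device)

# ...and load a stand-in lens checkpoint next to it.
buffer = io.BytesIO()
torch.save({"weight": torch.randn(8, 8)}, buffer)
buffer.seek(0)

state = load_lens_state(buffer, embedding)
print(state["weight"].device == embedding.weight.device)
```

Deriving the device from the embedding, rather than hard-coding one, means the lens simply follows wherever the user chose to put the model.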

Neelectric • Oct 28 '24 18:10