rust-bert
Add MPS as default if available
Can we show some love for the Mac M1 people out there? MPS doesn't seem any harder to select, when available, than CUDA, and tch-rs already includes it in its `Device` enum.

I'm happy to open a PR myself if anyone can suggest the right approach. Should it simply select MPS the same way `cuda_if_available` does, defaulting to it when available? Or should we check for CUDA first, then MPS, and only then default to CPU?
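For illustration, a minimal sketch of the second option, assuming the tch-rs version in use exposes `tch::Cuda::is_available()` and `tch::utils::has_mps()` (the helper name here is hypothetical):

```rust
use tch::Device;

/// Hypothetical helper mirroring `Device::cuda_if_available()`:
/// prefer CUDA, then MPS, and only then fall back to CPU.
fn preferred_device() -> Device {
    if tch::Cuda::is_available() {
        Device::Cuda(0)
    } else if tch::utils::has_mps() {
        Device::Mps
    } else {
        Device::Cpu
    }
}
```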
Were you able to get MPS to actually work? I haven't been successful yet. I tried altering the device in the model config to MPS, but I get the error `Internal torch error: supported devices include CPU, CUDA and HPU, however got MPS`. When I load torch in Python it says MPS is supported, so I think I have everything installed properly.
My best guess right now is that this is related to https://github.com/pytorch/pytorch/issues/88820, where JIT models created in a certain way won't load on MPS. I'm pretty new to all this though, so not completely sure.
I have it "working", but every feature I want to use needs a flag that lets it fall back to CPU, which it does. Basic PyTorch tensor usage executes on MPS, but anything in a pipeline fails.
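Presumably the flag in question is the `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable that comes up later in this thread. A hedged sketch of setting it from Rust, though the in-process call may come too late if libtorch reads the variable at load time:

```rust
// Must take effect before libtorch reads it; exporting it in the shell,
// e.g. `PYTORCH_ENABLE_MPS_FALLBACK=1 cargo run`, is the safer route.
std::env::set_var("PYTORCH_ENABLE_MPS_FALLBACK", "1");
```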
Ok, I actually did get it working! It's a little haphazard, though. The trick is to load the `VarStore` using the CPU device and then migrate it to the GPU device. From my other research, it appears that saving that migrated `VarStore` to disk would then allow it to be used directly with MPS, but I haven't tried that yet.
On my M1 Pro this runs about 2-3 times as fast as with CPU/AMX.
I'm using the Sentence Embedding pipeline. Here's the relevant change there:
```diff
 let transformer =
     SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
 var_store.load(transformer_weights_resource.get_local_path()?)?;
+var_store.set_device(tch::Device::Mps);
```
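For completeness, persisting the migrated weights (the untested step mentioned above) would presumably go through `tch`'s `VarStore::save`; the output path here is a placeholder:

```rust
// Write the MPS-resident weights back to disk; in principle this file
// could then be loaded directly into an MPS VarStore next time.
var_store.save("model_weights_mps.ot")?;
```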
Is there any way to access `var_store` after initializing the default model? I'm using the zero-shot classification pipeline.
Currently, I don't think so. I ended up just copying the pipeline code into my own project and modifying it for my purposes.
Apologies, I don't have an MPS device at hand for testing. It seems the issue is that creating the model by passing `Device::Mps` fails, but setting it to `Device::Cpu` for loading the weights and then changing the `VarStore` device works.

It would make sense to raise an issue with the upstream tch-rs, as to my understanding it should be possible to work with `Mps` the same way as with `Cuda`.
Would accessing the `var_store` of pipelines via a `get_mut_var_store(&mut self) -> &mut VarStore` interface help for your use case for the time being?
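For illustration, a sketch of how such an accessor might be used, assuming it were added to the pipeline structs (hypothetical at this point):

```rust
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
use tch::Device;

// Build the model normally (on CPU), then migrate its weights to MPS
// through the proposed accessor (hypothetical until merged).
let mut model = ZeroShotClassificationModel::new(Default::default())?;
model.get_mut_var_store().set_device(Device::Mps);
```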
That would help, yes! Any way to access the underlying `var_store` without having to rewrite the whole initialization would work.
@dimfeld @chrisvander I have opened a fix on the upstream library (https://github.com/LaurentMazare/tch-rs/pull/623) - if you have the time it would be great if you could perform some testing and see if this addresses the issue.
Hmm... I did make it to a run with MPS, but it seems to be ignoring the `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable and is failing with `TorchError "sparse_grad not supported in MPS yet"`. It should fall back to CPU there. Any tips?
```rust
let config = ZeroShotClassificationConfig {
    device: Device::Mps,
    ..Default::default()
};
let model = ZeroShotClassificationModel::new(config).unwrap();
```
Hello @chrisvander,

Apologies for the late response. I am pushing some changes that set the `sparse_grad` parameter to false across the library to improve device compatibility (https://github.com/guillaume-be/rust-bert/pull/404); that should solve the issue.
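For context, `sparse_grad` is presumably the flag on `tch`'s `Tensor::gather` binding, which matches the failing operation reported above; a dense-gradient call looks like this (tensor names are illustrative):

```rust
use tch::{Device, Kind, Tensor};

// Illustrative tensors; only the final argument matters here.
let logits = Tensor::rand(&[2, 5], (Kind::Float, Device::Cpu));
let index = Tensor::from_slice(&[1_i64, 3]).reshape(&[2, 1]);

// `sparse_grad = false` keeps the dense-gradient path, which the
// MPS backend supports.
let gathered = logits.gather(1, &index, false);
```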