
Add MPS as default if available

jkoudys opened this issue 2 years ago • 10 comments

Can we show some love for the Mac M1 people out there? MPS doesn't seem any harder to choose if available than CUDA, and tch-rs seems to include it in their Device enum.

I'm okay to PR myself, if anyone can suggest the right approach. Should it simply choose MPS the same way as cuda_if_available does, and default it if available? Should we start by checking for cuda, then checking for mps, and only then defaulting to CPU?
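To sketch the fallback chain I have in mind (the types and availability probes below are stand-ins for whatever tch-rs exposes, e.g. its `Device` enum and CUDA check — names here are hypothetical, not the actual API):

```rust
// Stand-in for tch's Device enum; in the real crate this would be
// tch::Device, and the probes would query the runtime.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Device {
    Cpu,
    Cuda(usize),
    Mps,
}

// Hypothetical probes; stubbed to false so the sketch is self-contained.
fn cuda_is_available() -> bool {
    false
}
fn mps_is_available() -> bool {
    false
}

// Prefer CUDA, then MPS, then fall back to CPU.
fn default_device() -> Device {
    if cuda_is_available() {
        Device::Cuda(0)
    } else if mps_is_available() {
        Device::Mps
    } else {
        Device::Cpu
    }
}

fn main() {
    println!("{:?}", default_device()); // prints "Cpu" with these stubs
}
```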

jkoudys avatar Dec 23 '22 15:12 jkoudys

Were you able to get MPS to actually work? I haven't been successful yet. I tried altering the device in the model config to MPS, but I get the error `Internal torch error: supported devices include CPU, CUDA and HPU, however got MPS`. When I load torch in Python it says MPS is supported, so I think I have everything installed properly.

My best guess right now is that this is related to https://github.com/pytorch/pytorch/issues/88820, where JIT models created in a certain way won't load on MPS. I'm pretty new to all this though, so not completely sure.

dimfeld avatar Dec 28 '22 19:12 dimfeld

I have it "working", but for every feature I want to use I need to set a flag that lets it fall back to CPU, which it does. Basic PyTorch tensor usage executes on MPS, but anything in a pipeline fails.

jkoudys avatar Dec 28 '22 19:12 jkoudys

Ok, I actually did get it working! It's a little haphazard though. The trick is to load the VarStore using the CPU device, and then migrate it to the GPU device. From my other research, it appears that saving that migrated VarStore to disk would then allow it to be used directly with MPS, but I haven't tried that yet.

On my M1 Pro this runs about 2-3 times as fast as with CPU/AMX.

I'm using the Sentence Embedding pipeline. Here's the relevant change there:

```rust
let transformer =
    SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
var_store.load(transformer_weights_resource.get_local_path()?)?;
// Added line: migrate the loaded weights from CPU to the MPS device.
var_store.set_device(tch::Device::Mps);
```

dimfeld avatar Dec 28 '22 19:12 dimfeld

Is there any way to access var_store after initializing the default model? I'm using the zero-shot classification pipeline.

chrisvander avatar Jan 20 '23 16:01 chrisvander

Currently, I don't think so. I ended up just copying the pipeline code into my own project and modifying it for my purposes.

dimfeld avatar Jan 20 '23 20:01 dimfeld

Apologies, I don't have an MPS device at hand for testing. It seems the issue is that creating the model with Device::Mps fails, but setting it to Device::Cpu for loading the weights and then changing the VarStore device works. It would make sense to raise an issue with the upstream tch-rs, as to my understanding it should be possible to work with Mps the same way as with Cuda.

Would accessing the var_store of pipelines via a .get_mut_var_store(&mut self) -> &mut VarStore interface help for your use case for the time being?
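As a rough sketch of the idea (the types here are minimal stand-ins, not the actual rust-bert/tch types), the pipeline would hand back a mutable reference to its store so a caller can migrate it after construction:

```rust
// Stand-in types illustrating the proposed accessor; in reality
// VarStore/Device come from tch-rs and the pipeline from rust-bert.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Device {
    Cpu,
    Mps,
}

struct VarStore {
    device: Device,
}

impl VarStore {
    fn set_device(&mut self, device: Device) {
        self.device = device;
    }
}

struct ZeroShotPipeline {
    var_store: VarStore,
}

impl ZeroShotPipeline {
    fn new() -> Self {
        // Weights start out loaded on CPU.
        ZeroShotPipeline {
            var_store: VarStore { device: Device::Cpu },
        }
    }

    // The proposed interface: expose the store mutably.
    fn get_mut_var_store(&mut self) -> &mut VarStore {
        &mut self.var_store
    }
}

fn main() {
    let mut model = ZeroShotPipeline::new();
    // Migrate the already-loaded weights to MPS after construction.
    model.get_mut_var_store().set_device(Device::Mps);
    println!("{:?}", model.var_store.device); // prints "Mps"
}
```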

guillaume-be avatar Jan 21 '23 08:01 guillaume-be

That would help, yes! Any way to access the underlying var_store without having to rewrite the whole initialization would work.

chrisvander avatar Jan 21 '23 14:01 chrisvander

@dimfeld @chrisvander I have opened a fix on the upstream library (https://github.com/LaurentMazare/tch-rs/pull/623) - if you have the time it would be great if you could perform some testing and see if this addresses the issue.

guillaume-be avatar Feb 12 '23 16:02 guillaume-be

Hmm... I did make it to a run with MPS, but it seems to be ignoring the PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable and is failing with a TorchError: `sparse_grad not supported in MPS yet`. It should fall back to CPU there. Any tips?

```rust
let config = ZeroShotClassificationConfig {
    device: Device::Mps,
    ..Default::default()
};
let model = ZeroShotClassificationModel::new(config).unwrap();
```
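For reference, I'm setting the fallback flag from the shell before launching the process, since (as I understand it) libtorch has to see it in the environment at startup — setting it from inside the program is too late:

```shell
# Enable CPU fallback for torch ops not yet implemented on MPS.
# Must be in the environment before the binary starts.
export PYTORCH_ENABLE_MPS_FALLBACK=1
echo "$PYTORCH_ENABLE_MPS_FALLBACK"
```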

chrisvander avatar Feb 13 '23 14:02 chrisvander

Hello @chrisvander ,

Apologies for the late response. I am pushing some changes that set the sparse_grad parameter to false across the library to improve device compatibility (https://github.com/guillaume-be/rust-bert/pull/404); that should solve the issue.

guillaume-be avatar Jul 16 '23 08:07 guillaume-be