rust-bert
rocm: Crash when loading model
Problem
All the examples I've tried crash while loading the model when using the pytorch-rocm GPU package. I know this is unsupported (it is not mentioned in the README), but I would like to contribute to getting it working.
Reproduce
Using the crate as a git dependency (to get PyTorch v2.0.0), install python-pytorch-rocm on Arch Linux, then run the code below:

```toml
[dependencies]
rust-bert = { git = "https://github.com/guillaume-be/rust-bert.git", rev = "5b8dcd2" }
```
```rust
use rust_bert::pipelines::text_generation::TextGenerationModel;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let model = TextGenerationModel::new(Default::default())?; // **CRASHES HERE**
    let input_context_1 = "The dog";
    let input_context_2 = "The cat was";
    let prefix = None; // Optional prefix prepended to each prompt; it is excluded from the generated output
    let output = model.generate(&[input_context_1, input_context_2], prefix);
    println!("{output:?}");
    Ok(())
}
```
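To confirm the crash is GPU-specific, a sketch that forces the pipeline onto the CPU may help (assuming the `TextGenerationConfig` struct and its `device` field at this revision; `GPT2` is the default model type and is only spelled out here for illustration):

```rust
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use tch::Device;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Force the CPU; Default::default() uses Device::cuda_if_available(),
    // which on the ROCm build selects the AMD GPU.
    let config = TextGenerationConfig {
        model_type: ModelType::GPT2,
        device: Device::Cpu,
        ..Default::default()
    };
    let model = TextGenerationModel::new(config)?;
    let output = model.generate(&["The dog"], None);
    println!("{output:?}");
    Ok(())
}
```

If this runs while the default (GPU) configuration crashes, the problem is in device placement rather than in the pipeline itself.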
Switching to the python-pytorch package (without ROCm support) fixes the crash. I don't have a CUDA GPU to check whether the Arch Linux packages are somehow to blame; I could try building PyTorch from source and testing with that.
Questions
- Are you interested in supporting AMD GPUs? tch seems to support them.
- Do you have any clues on how to debug this?
Hello @jalil-salame ,
Apologies for the delayed response. There is not much device-specific handling in this crate (we use the tch bindings and their device handling directly). I unfortunately do not have an AMD device to reproduce this. Could you please share the error trace so that I can help troubleshoot?
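For reference, a full trace can usually be captured via the standard `RUST_BACKTRACE` environment variable (the example name here assumes the `generation_gpt2` example mentioned below):

```shell
# Enable a full backtrace and save stderr for the issue report
RUST_BACKTRACE=full cargo run --example generation_gpt2 2> trace.log
```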
I have tried the generation_gpt2 example and it seems to work now, but I might be running on the CPU. I will continue testing on Friday, when I have time.