lightning-transformers

Running inference on GPU

Open · griff4692 opened this issue 2 years ago · 3 comments

🐛 Bug

import torch
from lightning_transformers.task.nlp.token_classification import TokenClassificationTransformer

ckpt_fn = "path/to/checkpoint.ckpt"  # placeholder: path to the trained checkpoint
model = TokenClassificationTransformer.load_from_checkpoint(ckpt_fn).to('cuda:0')
with torch.no_grad():
    model.hf_predict('this is a test sentence.')

Running this raises a device-mismatch error, because the model inputs are placed on the CPU while the model is on the GPU. Following the Hugging Face pipeline docs, I tried passing device='cuda:0' to model.hf_predict, but I get the following error:

  predictions = model.hf_predict(x, device='cuda:0')
  File "/root/lightning-transformers/lightning_transformers/core/model.py", line 183, in hf_predict
    return self.hf_pipeline(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/token_classification.py", line 189, in __call__
    return super().__call__(inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/base.py", line 987, in __call__
    preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(**kwargs)
TypeError: _sanitize_parameters() got an unexpected keyword argument 'device'

Any advice? This seems like a pretty standard use case so I think there should be an easy fix.

Thanks

griff4692, Jun 13 '22 16:06
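The TypeError above suggests that `device` is not accepted as a call-time argument: in Hugging Face `transformers`, the device is chosen when a pipeline is constructed, via `pipeline(..., device=...)`, where `-1` means CPU and `0` the first GPU. One possible workaround, sketched here under assumptions, is to build a fresh token-classification pipeline from the underlying Hugging Face model and tokenizer with an explicit device. The helper name `build_token_classification_pipeline` is hypothetical, and using it with the loaded checkpoint assumes the `TokenClassificationTransformer` exposes those objects as `model.model` and `model.tokenizer`; check this against your installed lightning-transformers version.

```python
import torch
from transformers import pipeline


def build_token_classification_pipeline(hf_model, tokenizer, device=None):
    """Hypothetical helper (not part of lightning-transformers).

    Builds a Hugging Face token-classification pipeline bound to `device`,
    using the transformers convention: -1 for CPU, 0 for the first GPU.
    """
    if device is None:
        # Prefer the first GPU when one is available, otherwise fall back to CPU.
        device = 0 if torch.cuda.is_available() else -1
    # The device is fixed at construction time; the pipeline then moves the
    # tokenized inputs to this device itself before running the forward pass.
    return pipeline(
        "token-classification",
        model=hf_model,
        tokenizer=tokenizer,
        device=device,
    )
```

With the checkpoint loaded as in the report, usage would look like `ner = build_token_classification_pipeline(model.model, model.tokenizer, device=0)` followed by `ner('this is a test sentence.')`. Because the pipeline knows its device up front, it places the inputs there itself, which avoids the CPU/GPU mismatch without passing `device` at call time.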

It looks like you can run the CLI predict script for inference:

https://github.com/PyTorchLightning/lightning-transformers/blob/master/lightning_transformers/cli/predict.py

However, I can't find example call arguments for it in the docs.

griff4692, Jun 13 '22 16:06

Hey @griff4692, sorry for the late response. This is definitely a bug in the predict API. We'll be removing the CLI soon, so we'll also need to update the docs!

SeanNaren, Jun 20 '22 21:06

@SeanNaren, I believe we have finished the API cleanup, so is all that's left to polish the docs? cc: @rohitgr7

Borda, Sep 14 '22 22:09