mesh-transformer-jax

GPT-J inference on TPU

Open airesearch38 opened this issue 3 years ago • 3 comments

Is it possible to use a TPU for inference?

The team at NLPCloud.io told me that's what they're doing, but I have no idea how. First, I don't know how to support advanced parameters like end_sequence (so the model stops generating when it reaches a specific token) or repetition penalty (see the Hugging Face text-generation parameters). Secondly, the TPU IPs seem to rotate on a regular basis and there's nothing you can do about it. So I'm not sure how to expose a TPU for inference through a REST API...
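For reference, repetition penalty (as described in the Hugging Face generation docs) rescales the logits of already-generated tokens before sampling, so it can be applied in any sampling loop. A minimal NumPy sketch of the CTRL-style penalty; the token IDs and values below are illustrative, not from mesh-transformer-jax:

```python
import numpy as np

def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Damp the logits of tokens that were already generated:
    positive logits are divided by `penalty`, negative ones multiplied,
    making repeated tokens less likely at the next sampling step."""
    logits = logits.copy()
    for token_id in set(generated_ids):
        if logits[token_id] > 0:
            logits[token_id] /= penalty
        else:
            logits[token_id] *= penalty
    return logits

# Example: token 2 was already generated, so its logit is damped.
logits = np.array([1.0, 2.0, 4.0, 0.5])
adjusted = apply_repetition_penalty(logits, generated_ids=[2], penalty=2.0)
# token 2's logit drops from 4.0 to 2.0; the others are untouched
```

A `penalty` of 1.0 is a no-op; values above 1.0 discourage repetition.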

Thanks in advance!

airesearch38 avatar Apr 15 '22 09:04 airesearch38

You may consider running device_serve.py on the TPU, combined with the streamlit approach used in the following repository:

https://github.com/vicgalle/gpt-j-api

leejason avatar Apr 15 '22 09:04 leejason

Interesting, thanks for the suggestion!

If I understand the code correctly, stop_sequence is not stopping the model's generation but simply splitting the result once the model finishes generating:

if stop_sequence is not None and stop_sequence in text:
    text = text.split(stop_sequence)[0] + stop_sequence

So generation takes the same time whether the stop_sequence token is reached or not. Am I correct?
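That reading can be checked in isolation: the snippet below (with made-up strings, not from the repository) wraps the same logic in a function and shows that it only trims text that has already been fully generated:

```python
def truncate_at_stop_sequence(text, stop_sequence):
    """Post-hoc truncation: the full completion was already generated;
    everything after the first stop sequence is simply discarded."""
    if stop_sequence is not None and stop_sequence in text:
        text = text.split(stop_sequence)[0] + stop_sequence
    return text

full_output = "Paris is the capital of France.\n###\nQ: next question..."
print(truncate_at_stop_sequence(full_output, "\n###"))
# -> Paris is the capital of France.
# -> ###
```

True early stopping would instead have to compare each newly sampled token against the stop sequence inside the generation loop and break out, which is what saves compute time.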

And I don't see a way to handle the fact that TPU IPs are regularly changing...

airesearch38 avatar Apr 15 '22 09:04 airesearch38

I was using streamlit as a quick web app for testing model inference and found it convenient. Indeed, the floating TPU IP is another issue. As for stop_sequence, I have no comment because I haven't run into any issue with it yet. In brief, device_serve.py works on TPU and could be a starting point.

leejason avatar Apr 16 '22 01:04 leejason