How to avoid compilation in a section of code?
❓ Questions and Help
We are using PyTorch XLA with TPUs to train a multi-modal language model.
We can give most of the code, such as the image encoding and the forward pass through the LLM backbone, static shapes, which XLA handles well. However, making the part that fuses image and text embeddings into the input embeddings static is extremely challenging.
Currently, we use mark_step to isolate that section from the rest of the code, allowing it to recompile each time. Although this part is computationally very light, the recompilation is extremely slow and often consumes the majority of training time.
Documentation on this issue is very hard to find, and we are exploring better solutions, such as running that part on the CPU, running it in eager mode, or not saving that part of the graph to avoid OOM errors during long training runs. Do you have any suggestions/pointers on how to work around this inefficiency?
The following pseudocode illustrates our problem:
```python
for batch in dataloader:  # loading data
    # these tensors have static shapes; XLA works great on them
    image_embeddings = image_encoder(raw_image_tensor)
    text_embeddings = get_text_embedding(text_token_idxs)
    xm.mark_step()
    # this part is very light in compute, but dynamic. We currently just
    # recompile this graph every single time :(
    input_embeddings = fuse_embedding(image_embeddings, text_embeddings, sequence_info_dict)
    xm.mark_step()
    # these tensors have static shapes; XLA works great on them
    output_logits = llm(input_embeddings)
    # loss compute / backward / optimizer step omitted
```
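For concreteness, here is a toy sketch of the kind of splicing a fuse step might do (purely illustrative: the placeholder positions and the padding scheme are assumptions, not our actual `fuse_embedding`, and it is written with NumPy so it runs anywhere):

```python
import numpy as np

# Hypothetical sketch: splice image embeddings into the text embedding
# sequence at placeholder positions, then pad to a fixed length so the
# downstream LLM always sees a static shape.
def fuse_embedding(text_emb, image_emb, image_positions, max_len):
    dim = text_emb.shape[1]
    fused = text_emb.copy()
    for pos, img in zip(image_positions, image_emb):
        fused[pos] = img                      # overwrite placeholder slots
    if fused.shape[0] < max_len:              # pad to the static length
        pad = np.zeros((max_len - fused.shape[0], dim), dtype=fused.dtype)
        fused = np.concatenate([fused, pad], axis=0)
    return fused

text = np.ones((5, 4), dtype=np.float32)
imgs = np.full((2, 4), 7.0, dtype=np.float32)
out = fuse_embedding(text, imgs, [1, 3], max_len=8)
assert out.shape == (8, 4)          # static output shape
assert (out[1] == 7.0).all()        # image rows spliced in
assert (out[5:] == 0.0).all()       # padding rows
```

The dynamism comes from `image_positions` and the pre-padding sequence length, which differ per sample even though the output shape is fixed.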
Great question. I have a couple of questions and a couple of suggestions.

Questions

1. It seems like even though `fuse_embedding` is dynamic, the shape of `input_embeddings` is static? This would explain why the `llm` HLO is static.
2. How dynamic is `fuse_embedding`? For example, are there around 100 different shape combinations possible in total, or can there literally be thousands of different shape combinations?
Suggestions

1. Have you used persistent caching? If not, please take a look at https://github.com/pytorch/xla/blob/master/API_GUIDE.md#compilation-caching. If there is relatively little dynamism in your code, enabling persistent caching would fix the issue (you can compile and remember all possible combinations).
2. Maybe try eager mode. This is an experimental feature, so you will need a nightly build. Take a look at https://github.com/pytorch/xla/blob/master/examples/eager/train_decoder_only_eager.py#L10. You can enable eager mode in the dynamic region and disable it right after. Or you can do something similar to https://github.com/pytorch/xla/blob/master/examples/eager/train_decoder_only_eager_with_compile.py, which turns on eager mode by default and manually picks the region to compile. Eager + compile is the UX I want to make the default next year, so I'd appreciate any feedback you have.
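If the dynamism comes mostly from variable lengths, one way to shrink the set of shapes enough for persistent caching to be viable is to pad each dynamic dimension up to a small set of buckets. A minimal sketch (`bucket_length` is a hypothetical helper, not a torch_xla API):

```python
import math

# Round a dynamic length up to the nearest power-of-two bucket, so the
# compiler only ever sees a handful of distinct padded shapes instead
# of thousands of raw lengths.
def bucket_length(n: int, min_bucket: int = 128) -> int:
    return max(min_bucket, 2 ** math.ceil(math.log2(n)))

assert bucket_length(1) == 128      # short inputs share one bucket
assert bucket_length(200) == 256
assert bucket_length(1000) == 1024
```

With, say, lengths up to 8192 and a minimum bucket of 128, this yields only seven possible padded shapes per dynamic dimension, which a persistent cache can easily cover.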
For the nightly build you can try

```shell
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly+20240701-cp310-cp310-linux_x86_64.whl
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
```

since last night's nightly seems to be broken.
Eager mode pretty much just compiles op by op. It compiles each op once per input shape, so the overall compile time is usually lower. Let me know how the above two suggestions work for you.
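As a rough mental model of that per-shape behavior (a toy analogy in plain Python, not the actual XLA machinery):

```python
# Toy analogy for per-shape compilation caching: the expensive "compile"
# step runs once per distinct input shape, then the compiled function is
# reused for every later call with that shape.
compile_cache = {}
compile_count = 0

def compiled_add(a, b):
    global compile_count
    shape = (len(a),)
    if shape not in compile_cache:
        compile_count += 1  # stands in for a slow compilation
        compile_cache[shape] = lambda x, y: [u + v for u, v in zip(x, y)]
    return compile_cache[shape](a, b)

assert compiled_add([1, 2], [3, 4]) == [4, 6]
compiled_add([5, 6], [7, 8])          # same shape: cache hit, no compile
compiled_add([1, 2, 3], [4, 5, 6])    # new shape: compiles again
assert compile_count == 2
```

This is why op-by-op eager compilation tends to be cheap after warm-up: individual ops see far fewer distinct shapes than a fused whole-region graph does.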
Thank you for the instructions!
Re Q1: that's correct! We deliberately pad both `raw_image_tensor` and `input_embeddings` to make their shapes static. Only `fuse_embedding` is recompiled, while `llm` and `image_encoder`, where most of the compute happens, stay static.
Re Q2: unfortunately it's very dynamic; it should be at least on the order of thousands of shape combinations.
Eager mode looks very promising; however, I'm unable to install the nightly:
```
tpu-vm:~$ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly': Expected end or semicolon (after name and no valid version specifier)
torch-xla==nightly
```
Hmm, that's weird. Can you access https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl? If I click on this link it just downloads the whl file for me. Also, is your Python version 3.10?
I can access it, and it's 3.10. But the issue is still there:
```
jiayipan@t1v-n-f6802337-w-0:~$ python --version
Python 3.10.12
jiayipan@t1v-n-f6802337-w-0:~$ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly+20240701': Expected end or semicolon (after name and no valid version specifier)
torch-xla==nightly+20240701
         ^
jiayipan@t1v-n-f6802337-w-0:~$ wget https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
--2024-07-03 17:29:57--  https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.12.207, 173.194.217.207, 74.125.26.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.12.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 83362771 (80M) [application/octet-stream]
Saving to: ‘torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl’
torch_xla-nightly+2 100%[===================>]  79.50M  88.9MB/s    in 0.9s
2024-07-03 17:29:58 (88.9 MB/s) - ‘torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl’ saved [83362771/83362771]
jiayipan@t1v-n-f6802337-w-0:~$ pip install
.bash_history
.bash_logout
.bashrc
.cache/
.config/
.local/
.profile
.ssh/
.viminfo
buckets/
prismatic-video-lms/
torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
torch_xla-nightly-cp310-cp310-linux_x86_64.whl
jiayipan@t1v-n-f6802337-w-0:~$ pip install torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly+20240701': Expected end or semicolon (after name and no valid version specifier)
torch-xla==nightly+20240701
         ^
jiayipan@t1v-n-f6802337-w-0:~$
```
Hmm, I can't repro this, which is a bit weird. Maybe manually rename the whl? Something like

```shell
mv torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl torch_xla-nightly-cp310-cp310-linux_x86_64.whl
```
Some updates.

**On reproducing the installation issue:** it turns out that the installation error only happens after

```shell
python3 -m pip install --upgrade pip
```

Given a clean tpu-v3 VM with ubuntu-22.04, you should be able to reproduce the error with:

```shell
python3 -m pip install --upgrade pip
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
```
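For what it's worth, the likely mechanism (my reading, not confirmed in this thread): newer pip versions strictly validate versions against PEP 440, and the `nightly+20240701` segment pip parses out of the wheel filename is not a valid PEP 440 version, because a public version must begin with a numeric release segment. A minimal illustration of that rule with a deliberately simplified regex:

```python
import re

# Simplified check of one PEP 440 rule: a public version must begin with
# a numeric release segment (optionally preceded by an epoch like "1!").
# This is not a full PEP 440 parser, just the rule "nightly" violates.
def starts_like_pep440(version: str) -> bool:
    return re.match(r"^(\d+!)?\d+(\.\d+)*", version) is not None

assert not starts_like_pep440("nightly+20240701")   # rejected, as in the error
assert starts_like_pep440("2.4.0.dev20240701+cpu")  # a valid nightly-style version
```

That would also explain why the pre-upgrade system pip accepted the wheel: older pip versions were more lenient about non-conforming versions.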
**On eager mode:** I tried eager mode! The code structure is basically as shown here:
```python
for batch in dataloader:  # loading data
    # these tensors have static shapes; XLA works great on them
    image_embeddings = image_encoder(raw_image_tensor)
    text_embeddings = get_text_embedding(text_token_idxs)
    xm.mark_step()
    # this part is very light in compute, but dynamic, so run it eagerly
    torch_xla.experimental.eager_mode(True)
    input_embeddings = fuse_embedding(image_embeddings, text_embeddings, sequence_info_dict)
    torch_xla.experimental.eager_mode(False)
    xm.mark_step()
    # these tensors have static shapes; XLA works great on them
    output_logits = llm(input_embeddings)
    # loss compute / backward / optimizer step omitted
```
Unfortunately, the code hangs and never reaches `output_logits = llm(input_embeddings)`. (It still works fine on the nightly when I disable eager mode.)
Do you have any suggestions for debugging? There are a few mark_steps around/within fuse_embedding; I'm not sure whether they cause any trouble.
Can you run with PT_XLA_DEBUG_LEVEL=1? This will print a message for every compilation if you are using the nightly. I am wondering whether it just keeps recompiling, or whether eager compilation (compiling each op) is too slow.
As I was facing the same issue with uv, I created a separate issue for the broken nightly filenames:
https://github.com/pytorch/xla/issues/7697