Exception on TPU when compiling gemma 2b
🐛 Bug
I tried to run inference on google/gemma-2b, and when I compile the model I get an exception.
To Reproduce
I run the script on a TPU v5e-litepod8; the script is here: https://gist.github.com/tengomucho/76fb3d630ac4a99c7f1f5e654700bb60.
Steps to reproduce the behavior:
DBG_COMPILE=1 python ./static_cache_test.py
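For reference, the relevant shape of the script is roughly the following (a simplified sketch, not the gist verbatim; the static-cache setup and generation loop are omitted, and the exact arguments are assumptions):

import os
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM, AutoTokenizer

device = xm.xla_device()
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", torch_dtype=torch.bfloat16
).to(device)

def decode_one_tokens(model, cur_token, pos_ids, cache_position):
    # One decode step against the static KV cache.
    logits = model(
        cur_token,
        position_ids=pos_ids,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
    )[0]
    return torch.argmax(logits[:, -1], dim=-1)[:, None]

if os.environ.get("DBG_COMPILE"):
    # This is the dynamo/torch_xla path that raises the exception below.
    decode_one_tokens = torch.compile(decode_one_tokens, backend="openxla_eval")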
Here's the stack trace I get:
Traceback (most recent call last):
File "/home/amoran/optimum-tpu/alvaro/static_cache_test.py", line 114, in <module>
next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position)
File "/home/amoran/optimum-tpu/alvaro/static_cache_test.py", line 34, in decode_one_tokens
logits = model(
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1019, in forward
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py", line 49, in fwd
compiled_graph = bridge.extract_compiled_graph(model, args)
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 571, in extract_compiled_graph
extract_internal(fused_module), node.args, None)
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 338, in extract_internal
xm.mark_step()
File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 891, in mark_step
torch_xla._XLAC._xla_step_marker(
RuntimeError: ./torch_xla/csrc/runtime/pjrt_computation_client.h:146 : Check failed: HasValue()
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::runtime::PjRtComputationClient::PjRtData::GetHandle()
torch::lazy::LazyGraphExecutor::RunPostOrder(std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&, torch::lazy::LazyGraphExecutor::SyncTensorCollection*)
torch_xla::XLAGraphExecutor::RunPostOrder(std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&, torch::lazy::LazyGraphExecutor::SyncTensorCollection*)
torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, bool, bool, bool)
torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::string>, bool)
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
PyObject_Call
_PyEval_EvalFrameDefault
PyObject_Call
_PyEval_EvalFrameDefault
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyRun_SimpleFileObject
_PyRun_AnyFileObject
Py_RunMain
Py_BytesMain
__libc_start_main
_start
*** End stack trace ***
buffer with shape bf16[2,1,1024,256] on device TPU:0 is null
Expected behavior
The script should run until the end, printing timings and results.
Environment
- Reproducible on XLA backend: TPU (v5e-litepod8)
- torch_xla version: 2.2.0. I also tried the nightly version (on Docker) and got a different error.
Hmm this seems like a real error with dynamo. @alanwaketan Do you know who benchmarks Gemma inference with dynamo?
@JackCaoG No, I don't think there is anyone working on it at this moment.
What's the error you got when running with nightly?
I just re-tried with the nightly Docker image (sha256:9c517c2514540d373cbb6d06333df144df1a3099626704558458da4cdf49adf6). With or without compilation, I get the same error:
Traceback (most recent call last):
File "/workspace/alvaro/static_cache_test.py", line 117, in <module>
next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position)
File "/workspace/alvaro/static_cache_test.py", line 32, in decode_one_tokens
logits = model(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1025, in forward
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py", line 51, in fwd
compiled_graph = bridge.extract_compiled_graph(model, args)
File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 572, in extract_compiled_graph
collector.run(*xla_args)
File "/usr/local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 145, in run
self.env[node] = self.run_node(node)
File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 460, in run_node
result = super().run_node(n)
File "/usr/local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 202, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 319, in call_module
return submod(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 163, in forward
return F.embedding(
File "/usr/local/lib/python3.10/site-packages/torch/nn/functional.py", line 2264, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: torch_xla/csrc/tensor_ops.cpp:248 : Check failed: indices->dtype() == at::ScalarType::Long (Int vs. Long)
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::tensor_ops::Embedding(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
torch_xla::tensor_methods::embedding(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
torch_xla::XLANativeFunctions::embedding_symint(at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)
at::_ops::embedding::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)
at::_ops::embedding::call(at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
PyVectorcall_Call
_PyEval_EvalFrameDefault
PyVectorcall_Call
_PyEval_EvalFrameDefault
PyVectorcall_Call
_PyEval_EvalFrameDefault
PyVectorcall_Call
_PyEval_EvalFrameDefault
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyRun_SimpleFileObject
_PyRun_AnyFileObject
Py_RunMain
Py_BytesMain
__libc_start_main
_start
*** End stack trace ***
The only part that differs is the final message in the error. With compilation on, I get this at the end:
While executing %hidden_states : [num_users=1] = call_module[target=L__self___model_embed_tokens](args = (%l_input_ids_,), kwargs = {})
Original traceback:
File "/usr/local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1073, in forward
outputs = self.model(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 869, in forward
inputs_embeds = self.embed_tokens(input_ids)
@ManfeiBai Can you take a look?
Well, Check failed: indices->dtype() == at::ScalarType::Long (Int vs. Long) seems easy to resolve: it looks like we are trying to index with an int tensor, but XLA expects it to be Long.
It is from https://github.com/pytorch/xla/blob/bf60b6233de8b84eed2b666c1951786ab26294fa/torch_xla/csrc/tensor_ops.cpp#L195-L199
One easy fix is to just convert the index to Long if it is also an int type (I don't think there is an int128, so it should be safe to convert as long as it is an int); a sketch of that cast is below.
wait, I think it is from https://github.com/pytorch/xla/blob/bf60b6233de8b84eed2b666c1951786ab26294fa/torch_xla/csrc/tensor_ops.cpp#L245-L248
The check already includes at::kInt. It is fixed by https://github.com/pytorch/xla/pull/6718, which was merged today. If you wait until tomorrow, it should be fixed.
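A minimal standalone sketch of that cast workaround for anyone on an affected build (hypothetical example; the toy embedding sizes and tensors are made up, not taken from the repro script):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
emb = torch.nn.Embedding(16, 4).to(device)

# int32 indices hit the dtype check in the XLA embedding lowering on the
# affected build; casting to Long (int64) before the lookup avoids it.
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.int32, device=device)
out = emb(input_ids.to(torch.long))
xm.mark_step()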
Hi, thanks. I will test locally with torch_xla built with https://github.com/pytorch/xla/pull/6718 to confirm. Do we have a repo or commands to reproduce it first?
So, with current nightly I do not see errors anymore if running without compilation. With compilation enabled, I get the same error I see in the issue description.
@ManfeiBai Can you try reproducing it?
sure, will do
Hi @tengomucho, is there a link or repo I could pull to reproduce this failure on my local device?
@ManfeiBai I think @tengomucho linked the script in the description.
@JackCaoG Based on the latest reply from @tengomucho, it seems like a dynamo issue. Can you take a look as well?
I most likely won't have cycles until this Friday; I will try to take a look then.
Thanks. I synced with @alanwaketan and reproduced it locally: https://gist.github.com/ManfeiBai/9ed8b9790fe849d92df622653b398035
With DBG_COMPILE=1, I saw this error:
RuntimeError: ./torch_xla/csrc/runtime/pjrt_computation_client.h:153 : Check failed: HasValue()
Without DBG_COMPILE=1, the program finished.
Yeah, dynamo is only enabled if DBG_COMPILE=1, so this is aligned with @tengomucho's observation. One thing to try is to use openxla instead of openxla_eval. The openxla backend will run aot-autograd and atenify the ops; I'm not sure if this will make any difference to the assert error above.
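Concretely, the suggested change in the repro would be something like this (sketch only; the decode step below is a stand-in for the one in the gist, and only the backend name changes):

import torch

# Hypothetical decode step standing in for the one in the repro script.
def decode_one_tokens(model, cur_token, pos_ids, cache_position):
    logits = model(cur_token, position_ids=pos_ids,
                   cache_position=cache_position,
                   return_dict=False, use_cache=True)[0]
    return torch.argmax(logits[:, -1], dim=-1)[:, None]

# Only change vs. the failing setup: backend="openxla" instead of "openxla_eval".
# openxla runs aot-autograd and atenifies the ops before handing the graph to XLA.
decode_one_tokens = torch.compile(decode_one_tokens, backend="openxla")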
Hey @JackCaoG you are right, setting the compilation backend to openxla made the error disappear!
@tengomucho Yay! We should consider removing the openxla_eval backend, as openxla seems more mature. Let me check with the team regarding the performance difference between the two.
@JackCaoG We use openxla_eval by default in most of the examples, lol. We can re-benchmark it to see if the performance gaps are gone.
@alanwaketan Sounds good. Based on the torchbench results we got recently, openxla has a higher passing rate and similar performance compared to openxla_eval.
@tengomucho I am going to close this issue for now; feel free to open a new one if you run into other issues.