
Exception on TPU when compiling gemma 2b

Open tengomucho opened this issue 11 months ago • 9 comments

🐛 Bug

I tried to run inference on google/gemma-2b, and when I compile the model I get an exception.

To Reproduce

I ran this script on a TPU v5e-litepod8: https://gist.github.com/tengomucho/76fb3d630ac4a99c7f1f5e654700bb60.
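
For context, the relevant part of the script looks roughly like this (a condensed sketch based on the names in the stack trace below; the gist is the authoritative version, the static-cache setup is omitted, and the openxla_eval backend is inferred from the discussion further down):

```python
import os

import torch
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM

device = xm.xla_device()
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", torch_dtype=torch.bfloat16
).to(device)
# ... static KV cache and prompt setup omitted ...

if os.environ.get("DBG_COMPILE"):
    # Compiling through dynamo's XLA backend is what triggers the failure below.
    model = torch.compile(model, backend="openxla_eval")

def decode_one_tokens(model, cur_token, pos_ids, cache_position):
    # One greedy decoding step at a fixed cache position.
    logits = model(
        cur_token,
        position_ids=pos_ids,
        cache_position=cache_position,
        use_cache=True,
    ).logits
    return torch.argmax(logits[:, -1], dim=-1)[:, None]

# Decode loop (the call that fails in the trace below):
# next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position)
```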

Steps to reproduce the behavior:

DBG_COMPILE=1 python ./static_cache_test.py

Here's a stack trace I get:

Traceback (most recent call last):
  File "/home/amoran/optimum-tpu/alvaro/static_cache_test.py", line 114, in <module>
    next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position)
  File "/home/amoran/optimum-tpu/alvaro/static_cache_test.py", line 34, in decode_one_tokens
    logits = model(
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1019, in forward
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py", line 49, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 571, in extract_compiled_graph
    extract_internal(fused_module), node.args, None)
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 338, in extract_internal
    xm.mark_step()
  File "/home/amoran/Dev/venv/hf/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 891, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: ./torch_xla/csrc/runtime/pjrt_computation_client.h:146 : Check failed: HasValue() 
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        torch_xla::runtime::PjRtComputationClient::PjRtData::GetHandle()
        torch::lazy::LazyGraphExecutor::RunPostOrder(std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&, torch::lazy::LazyGraphExecutor::SyncTensorCollection*)
        torch_xla::XLAGraphExecutor::RunPostOrder(std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&, torch::lazy::LazyGraphExecutor::SyncTensorCollection*)
        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
        torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, bool, bool, bool)
        torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::string>, bool)




        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault

        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault


        PyObject_Call
        _PyEval_EvalFrameDefault


        PyObject_Call
        _PyEval_EvalFrameDefault


        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        PyObject_Call
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault

        PyEval_EvalCode



        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain

        __libc_start_main
        _start
*** End stack trace ***
buffer with shape bf16[2,1,1024,256] on device TPU:0 is null

Expected behavior

The script should run until the end, printing timings and results.

Environment

  • Reproducible on XLA backend: TPU (v5e-litepod8)
  • torch_xla version: 2.2.0. I also tried the nightly version (via Docker) and got a different error.

tengomucho avatar Mar 11 '24 07:03 tengomucho

Hmm this seems like a real error with dynamo. @alanwaketan Do you know who benchmarks Gemma inference with dynamo?

JackCaoG avatar Mar 11 '24 17:03 JackCaoG

@JackCaoG No, I don't think there is anyone working on it at this moment.

alanwaketan avatar Mar 11 '24 19:03 alanwaketan

what's the error you got when running with nightly?

JackCaoG avatar Mar 11 '24 21:03 JackCaoG

I just re-tried with the nightly Docker image (sha256:9c517c2514540d373cbb6d06333df144df1a3099626704558458da4cdf49adf6). With or without compilation, I get the same error:

Traceback (most recent call last):
  File "/workspace/alvaro/static_cache_test.py", line 117, in <module>
    next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position)
  File "/workspace/alvaro/static_cache_test.py", line 32, in decode_one_tokens
    logits = model(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1025, in forward
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py", line 51, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 572, in extract_compiled_graph
    collector.run(*xla_args)
  File "/usr/local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 460, in run_node
    result = super().run_node(n)
  File "/usr/local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 319, in call_module
    return submod(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/functional.py", line 2264, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: torch_xla/csrc/tensor_ops.cpp:248 : Check failed: indices->dtype() == at::ScalarType::Long (Int vs. Long)
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        torch_xla::tensor_ops::Embedding(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
        torch_xla::tensor_methods::embedding(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
        torch_xla::XLANativeFunctions::embedding_symint(at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)


        at::_ops::embedding::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)


        at::_ops::embedding::call(at::Tensor const&, at::Tensor const&, c10::SymInt, bool, bool)


        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault



        _PyEval_EvalFrameDefault



        _PyEval_EvalFrameDefault

        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        _PyObject_Call
        _PyEval_EvalFrameDefault


        _PyEval_EvalFrameDefault


        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault



        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault



        PyVectorcall_Call
        _PyEval_EvalFrameDefault


        PyVectorcall_Call
        _PyEval_EvalFrameDefault


        PyVectorcall_Call
        _PyEval_EvalFrameDefault

        PyVectorcall_Call
        _PyEval_EvalFrameDefault


        PyVectorcall_Call
        _PyEval_EvalFrameDefault

        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        PyEval_EvalCode



        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***

The only part that differs is the final message in the error. With compilation on, I get this at the end:

While executing %hidden_states : [num_users=1] = call_module[target=L__self___model_embed_tokens](args = (%l_input_ids_,), kwargs = {})
Original traceback:
  File "/usr/local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 1073, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 869, in forward
    inputs_embeds = self.embed_tokens(input_ids)

tengomucho avatar Mar 12 '24 10:03 tengomucho

@ManfeiBai Can you take a look?

alanwaketan avatar Mar 12 '24 18:03 alanwaketan

Well, Check failed: indices->dtype() == at::ScalarType::Long (Int vs. Long) seems easy to resolve: it looks like we are trying to index with an int tensor, but XLA expects it to be Long.

JackCaoG avatar Mar 12 '24 18:03 JackCaoG

It is from https://github.com/pytorch/xla/blob/bf60b6233de8b84eed2b666c1951786ab26294fa/torch_xla/csrc/tensor_ops.cpp#L195-L199

One easy fix is to just convert the index to Long if it is an int type (I don't think there is an int128, so it should be safe to convert as long as it is an int).
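
For illustration, the caller-side version of that cast would look like this (a minimal sketch, not the fix that eventually landed in torch_xla):

```python
import torch
import torch.nn.functional as F

weight = torch.randn(8, 4)                              # embedding table
indices = torch.tensor([[1, 2, 3]], dtype=torch.int32)  # int32 indices trip the XLA check

# Casting to long (int64) before the lookup avoids the dtype assertion;
# in the gemma script this corresponds to keeping input_ids as torch.long.
out = F.embedding(indices.long(), weight)
print(out.shape)  # torch.Size([1, 3, 4])
```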

JackCaoG avatar Mar 12 '24 18:03 JackCaoG

wait, I think it is from https://github.com/pytorch/xla/blob/bf60b6233de8b84eed2b666c1951786ab26294fa/torch_xla/csrc/tensor_ops.cpp#L245-L248

The check already includes at::kInt. It is fixed by https://github.com/pytorch/xla/pull/6718, which merged today. If you wait until tomorrow's nightly, it should be fixed.

JackCaoG avatar Mar 12 '24 18:03 JackCaoG

@ManfeiBai Can you take a look?

Hi, thanks. I will test locally with torch_xla built with https://github.com/pytorch/xla/pull/6718 to confirm that. Do we have a repo or commands to repro it first?

ManfeiBai avatar Mar 12 '24 19:03 ManfeiBai

So, with the current nightly I no longer see errors when running without compilation. With compilation enabled, I get the same error as in the issue description.

tengomucho avatar Mar 13 '24 12:03 tengomucho

@ManfeiBai Can you try reproducing it?

alanwaketan avatar Mar 13 '24 18:03 alanwaketan

@ManfeiBai Can you try reproducing it?

sure, will do

ManfeiBai avatar Mar 13 '24 19:03 ManfeiBai

So, with the current nightly I no longer see errors when running without compilation. With compilation enabled, I get the same error as in the issue description.

Hi @tengomucho, is there a link or repo I could pull to my local device to reproduce this failure locally?

ManfeiBai avatar Mar 13 '24 19:03 ManfeiBai

So, with the current nightly I no longer see errors when running without compilation. With compilation enabled, I get the same error as in the issue description.

Hi @tengomucho, is there a link or repo I could pull to my local device to reproduce this failure locally?

@ManfeiBai I think @tengomucho linked the script in the description.

alanwaketan avatar Mar 13 '24 21:03 alanwaketan

@JackCaoG Based on the latest reply from @tengomucho, it seems like a dynamo issue. Can you take a look as well?

alanwaketan avatar Mar 13 '24 22:03 alanwaketan

I most likely won't have cycles until this Friday; I will try to take a look then.

JackCaoG avatar Mar 13 '24 22:03 JackCaoG

So, with the current nightly I no longer see errors when running without compilation. With compilation enabled, I get the same error as in the issue description.

Hi @tengomucho, is there a link or repo I could pull to my local device to reproduce this failure locally?

@ManfeiBai I think @tengomucho linked the script in the description.

Thanks. I synced with @alanwaketan and reproduced it locally: https://gist.github.com/ManfeiBai/9ed8b9790fe849d92df622653b398035

With DBG_COMPILE=1, I saw this error:

RuntimeError: ./torch_xla/csrc/runtime/pjrt_computation_client.h:153 : Check failed: HasValue() 

Without DBG_COMPILE=1, the program finished.

ManfeiBai avatar Mar 13 '24 23:03 ManfeiBai

Yeah, dynamo is only enabled if DBG_COMPILE=1, which is aligned with @tengomucho's observation. One thing to try is to use openxla instead of openxla_eval. The openxla backend will run aot-autograd and atenify the ops; I am not sure if that will make any difference to the assert error above.
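
For reference, that switch is just the backend string passed to torch.compile; a self-contained toy example (a tiny linear layer standing in for the gemma model):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(4, 4).to(device)

# "openxla" runs aot-autograd and decomposes to aten ops before lowering to XLA,
# unlike "openxla_eval", which traces the module more directly.
compiled = torch.compile(model, backend="openxla")
out = compiled(torch.randn(2, 4, device=device))
xm.mark_step()
```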

JackCaoG avatar Mar 14 '24 01:03 JackCaoG

Hey @JackCaoG, you are right: setting the compilation backend to openxla made the error disappear!

tengomucho avatar Mar 14 '24 14:03 tengomucho

@tengomucho yay. We should consider removing the openxla_eval backend, as openxla seems more mature. Let me check with the team regarding the performance difference between the two.

JackCaoG avatar Mar 14 '24 16:03 JackCaoG

@JackCaoG We use openxla_eval by default in most of our examples, lol. We can re-benchmark to see if the performance gaps are gone.

alanwaketan avatar Mar 14 '24 19:03 alanwaketan

@alanwaketan Sounds good. Based on the torchbench results we got recently, openxla has a higher passing rate and similar performance compared to openxla_eval.

JackCaoG avatar Mar 14 '24 20:03 JackCaoG

@tengomucho I am going to close this issue for now; feel free to open a new one if you run into other issues.

JackCaoG avatar Mar 14 '24 22:03 JackCaoG