Can torch-xla be used with C++ libtorch inference?
PyTorch supports the use of TorchScript models in C++ programs using libtorch.

Is it possible to use torch-xla with C++? How would one install / link / load torch-xla in a C++ program? And finally, in the C++ API, how do you get the XLA device to pass to PyTorch? (I assume everything else works the same, including synchronization happening automatically upon transfer of tensors to the CPU?) For reference, the plain-libtorch flow I'm starting from is sketched below.
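This is the standard libtorch pattern from the PyTorch C++ docs ("model.pt" and the input shape are just placeholders):

```cpp
#include <torch/script.h>
#include <vector>

int main() {
  // Load a TorchScript export; "model.pt" is a placeholder path.
  torch::jit::script::Module module = torch::jit::load("model.pt");

  // Run inference on CPU with a dummy input.
  std::vector<torch::jit::IValue> inputs;
  inputs.emplace_back(torch::ones({1, 3, 224, 224}));
  torch::Tensor output = module.forward(inputs).toTensor();
  return 0;
}
```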
Thanks!
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Should we reopen this and prioritize the inference workload and C++ device/XLA interop?
I think it's possible to use torch-xla with C++. Examples are the C++ gtest programs in test/cpp/. You may write a custom cmake file similar to https://github.com/pytorch/xla/blob/master/test/cpp/CMakeLists.txt to compile your code. There are many examples in test/cpp/test_aten_xla_tensor.cpp that show how to interact with the XLA device in C++.
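Roughly, interacting with the XLA device from C++ looks like this (a minimal untested sketch; it assumes the torch_xla libraries are linked in so the XLA device type is actually registered):

```cpp
#include <iostream>
#include <torch/torch.h>

int main() {
  // DeviceType::XLA is defined in core PyTorch; linking the torch_xla
  // libraries registers the actual backend behind it.
  torch::Device device(at::kXLA, /*index=*/0);

  // Ops on XLA tensors are recorded lazily rather than run eagerly.
  torch::Tensor a = torch::ones({2, 2}, torch::TensorOptions().device(device));
  torch::Tensor b = a + a;

  // Moving the result to CPU forces the recorded graph to compile and run.
  std::cout << b.to(torch::kCPU) << std::endl;
  return 0;
}
```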
More importantly, if your primary goal is to do inference, there are a couple of known issues we have to address before we can get good inference performance. One issue is that we currently trace the graph on every step, which is fine for training since the trace overlaps with execution, but it is not ideal for inference.
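Concretely, a naive inference loop pays that cost on every request. A hypothetical sketch (function and variable names are mine, not an existing API):

```cpp
#include <torch/script.h>
#include <vector>

// Hypothetical helper: run a TorchScript module over a stream of inputs
// on the XLA device. Each forward() is only *traced* lazily; the
// .to(kCPU) sync point compiles (cached after the first hit) and
// executes the traced graph. In training the next step's trace overlaps
// with device execution; in inference the trace sits on the latency
// path of every request.
void run_inference(torch::jit::script::Module& module,
                   const std::vector<torch::Tensor>& batches,
                   const torch::Device& device) {
  for (const torch::Tensor& input : batches) {
    std::vector<torch::jit::IValue> args;
    args.emplace_back(input.to(device));
    torch::Tensor out = module.forward(args).toTensor();  // traced, not run
    out = out.to(torch::kCPU);  // forces compile + execute
  }
}
```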
@JackCaoG mentioned this limitation to me on Slack. Am I understanding right that the issue is you would have sequential, repeated execution of trace, JIT-compile, run, fetch results?
I frankly think dynamo inference will achieve better performance, since it skips the tracing (technically you can also do that in C++, but you need to understand how LTC, the lazy tensor core, works). Check out the blog post at https://dev-discuss.pytorch.org/t/torchdynamo-update-10-integrating-with-pytorch-xla-for-inference-and-training/935 and the example in https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo.py#L49
Thanks for the thoughts on that, @JackCaoG!
```cpp
// Untested sketch: "model.pt" and the device index are placeholders.
torch::jit::script::Module model = torch::jit::load("model.pt");
model.to(torch::Device(at::kXLA, 0));
```

I have not tested this.