torch.jit.trace fails on a very simple conv layer on TPU but succeeds on both MPS and CPU
Minimal reproduction example:
import torch
import torch_xla
import torch_xla.core.xla_model as xm


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, sample):
        return self.conv_in(sample)


def trace_model(device):
    model = SimpleModel().to(device)
    model.eval()  # Set the model to evaluation mode

    # Create a sample input
    sample = torch.randn(1, 3, 32, 32, device=device)

    # Attempt to trace the model
    try:
        with torch.no_grad():  # Disable gradient computation
            traced_model = torch.jit.trace(model, sample)
        print(f"Tracing successful on {device}")

        # Test the traced model
        test_output = traced_model(sample)
        print(f"Test output shape: {test_output.shape}")
    except Exception as e:
        print(f"Tracing failed on {device}: {str(e)}")


def main():
    # Test on CPU
    print("Testing on CPU:")
    trace_model(torch.device("cpu"))

    # Test on XLA device
    print("\nTesting on XLA device:")
    xla_device = xm.xla_device()
    trace_model(xla_device)


if __name__ == "__main__":
    main()
Tracing works fine on CPU, but on the XLA device it fails with:
[ XLAFloatType{1,64,32,32} ]) of traced region did not have observable data dependence with trace inputs; this probably indicates your program cannot be understood by the tracer.
Tested on a TPU v4 pod.
This is especially annoying because the obvious workaround, tracing on CPU and then loading the traced model onto the XLA device, isn't possible due to https://github.com/pytorch/pytorch/issues/96448
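For reference, the CPU-side half of that workaround can be sketched as below. This is a minimal sketch, not necessarily the exact code path that hits the linked issue: the helper names `trace_on_cpu_and_serialize` and `load_traced` are hypothetical, and passing the XLA device through `map_location` is an assumption about how the load step would be attempted. Only the CPU round trip is executed; the XLA step is left as a comment.

```python
import io

import torch


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, sample):
        return self.conv_in(sample)


def trace_on_cpu_and_serialize():
    """Trace the model on CPU and return the TorchScript module as bytes."""
    model = SimpleModel().eval()
    sample = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        traced = torch.jit.trace(model, sample)
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)
    return buffer.getvalue()


def load_traced(serialized, device):
    """Load a serialized module, asking torch.jit.load to remap it to `device`."""
    return torch.jit.load(io.BytesIO(serialized), map_location=device)


serialized = trace_on_cpu_and_serialize()

# CPU -> CPU round trip: this part works.
cpu_model = load_traced(serialized, torch.device("cpu"))
cpu_output = cpu_model(torch.randn(1, 3, 32, 32))
print(f"CPU round trip output shape: {tuple(cpu_output.shape)}")

# Assumed XLA step (not executed here), which is where the workaround breaks:
#   import torch_xla.core.xla_model as xm
#   xla_model = load_traced(serialized, xm.xla_device())
```

Serializing to an in-memory buffer keeps the sketch free of temporary files; on a TPU host the commented lines would replace the CPU load.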