
torch.jit.trace fails with a very simple conv layer on TPU but succeeds on both MPS and CPU

BitPhinix opened this issue 1 year ago · 0 comments

Minimal reproduction example:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, sample):
        return self.conv_in(sample)


def trace_model(device):
    model = SimpleModel().to(device)
    model.eval()  # Set the model to evaluation mode

    # Create a sample input
    sample = torch.randn(1, 3, 32, 32, device=device)

    # Attempt to trace the model
    try:
        with torch.no_grad():  # Disable gradient computation
            traced_model = torch.jit.trace(model, sample)
        print(f"Tracing successful on {device}")

        # Test the traced model
        test_output = traced_model(sample)
        print(f"Test output shape: {test_output.shape}")

    except Exception as e:
        print(f"Tracing failed on {device}: {str(e)}")


def main():
    # Test on CPU
    print("Testing on CPU:")
    trace_model(torch.device("cpu"))

    # Test on XLA device
    print("\nTesting on XLA device:")
    xla_device = xm.xla_device()
    trace_model(xla_device)


if __name__ == "__main__":
    main()

This works fine on CPU, but on the XLA device tracing fails with:

[ XLAFloatType{1,64,32,32} ]) of traced region did not have observable data dependence with trace inputs; this probably indicates your program cannot be understood by the tracer.

Tested on a TPU v4 pod.

This is especially annoying because tracing on CPU and then loading the traced model onto the XLA device isn't possible either, due to https://github.com/pytorch/pytorch/issues/96448 (as sketched below).
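
For reference, a minimal sketch of that alternative (trace on CPU, then move the traced module to the XLA device), reusing SimpleModel from the reproduction above; per the linked PyTorch issue, this path does not work either:

# Sketch of the CPU-trace-then-move workaround referenced above.
# Assumes SimpleModel and the imports from the reproduction script.
cpu_model = SimpleModel().eval()
cpu_sample = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    traced = torch.jit.trace(cpu_model, cpu_sample)  # tracing succeeds on CPU

xla_device = xm.xla_device()
traced_xla = traced.to(xla_device)                   # move the ScriptModule to XLA
out = traced_xla(cpu_sample.to(xla_device))          # fails, see pytorch/pytorch#96448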

BitPhinix · Jul 02 '24 07:07