TorchSharp
Modules should be created on the default device
The following F# program crashes:
open TorchSharp
open type torch
open type torch.nn
torch.set_default_device(torch.CUDA)
let embedding = Embedding(2, 3)
let tensor = tensor([|1|])
tensor --> embedding // boom
The error message is: "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!" This occurs because the embedding is created on the CPU, even though CUDA is set as the default device.
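For reference, explicitly moving the module to CUDA avoids the crash. Here is a minimal sketch of that workaround (assuming the module's to(Device) overload behaves as it does in PyTorch; to has to be escaped with backticks because it is an F# keyword):
open TorchSharp
open type torch
open type torch.nn

torch.set_default_device(torch.CUDA)
let embedding = Embedding(2, 3)
// Move the module's parameters to CUDA by hand before using it.
embedding.``to``(torch.CUDA) |> ignore
let tensor = tensor([|1|])
tensor --> embedding // no device mismatch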
It looks like all modules created this way ignore the default device, although this doesn't cause a crash for every module type. I'm not sure whether the absence of a crash is the expected behavior, but it surprised me. For example, the following code works, even though the Linear module is created on the CPU:
torch.set_default_device(torch.CUDA)
let linear = Linear(2, 3)
let tensor = tensor([|1.0f; 2.0f|])
tensor --> linear
The output tensor in this case is on the CUDA device, even though the Linear module itself is on the CPU.
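One way to see the mismatch directly (a rough sketch, continuing from the snippet above; I'm assuming Linear exposes its weight parameter and that Tensor.device reports the placement):
// Where the parameters live vs. where the output lands.
printfn "weight device: %O" linear.weight.device        // expected: cpu
printfn "output device: %O" (tensor --> linear).device  // expected: cuda:0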
The Linear module is implemented in managed code, which is why it works.
https://github.com/dotnet/TorchSharp/blob/3760ba3e97ca09a35ad8eab1d5128b62d6d2c2ef/src/TorchSharp/NN/Linear.cs
Embedding is handled by the native torch library.
https://github.com/dotnet/TorchSharp/blob/3760ba3e97ca09a35ad8eab1d5128b62d6d2c2ef/src/TorchSharp/NN/Embedding.cs#L60-L68
As for modules not being created on the default device, this seems to be a PyTorch "bug" that is more of a feature at this point. I doubt it will be changed there.
I too have encountered the annoyance of having to move my modules over to the default device. I would be in favor of having TorchSharp do this automatically. I think a modification to this function would be all that's required:
https://github.com/dotnet/TorchSharp/blob/3760ba3e97ca09a35ad8eab1d5128b62d6d2c2ef/src/TorchSharp/NN/Module.cs#L1138-L1145
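In the meantime, a small user-side helper can centralize the manual move. This is only a sketch (the onDefaultDevice name is mine, not part of TorchSharp), assuming the module's to(Device) overload:
open TorchSharp
open type torch
open type torch.nn

/// Hypothetical helper: move a freshly created module to the given device
/// and hand it back so it can be bound or piped as usual.
let onDefaultDevice<'M when 'M :> torch.nn.Module> (device: torch.Device) (m: 'M) : 'M =
    (m :> torch.nn.Module).``to``(device) |> ignore
    m

// Usage: create the module, then immediately place it on CUDA.
let embedding = Embedding(2, 3) |> onDefaultDevice torch.CUDA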
@brianberns, thanks for bringing this up.
I've written some code to reproduce a similar scenario with PyTorch, and it worked as expected (CUDA is used as the device for both the Embedding and the tensor):
import torch
import torch.nn as nn
# Set the default device to CUDA
torch.set_default_device('cuda')
embedding = nn.Embedding(2, 3)
tensor = torch.tensor([1])
result = embedding(tensor)
print(result)
It seems there are some missing checks (or wrong default assumptions) during the creation of some modules.
We'll investigate and try to address this issue so that TorchSharp can take the default device into account automatically.