
Modules should be created on the default device

Open brianberns opened this issue 10 months ago • 2 comments

The following F# program crashes:

open TorchSharp
open type torch
open type torch.nn

torch.set_default_device(torch.CUDA)

let embedding = Embedding(2, 3)
let tensor = tensor([|1|])
tensor --> embedding   // boom

The error message is: "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!" This occurs because the embedding is created on the CPU, even though CUDA is set as the default device.
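
A workaround that avoids the crash is to move the module to the target device explicitly after constructing it. This is a sketch rather than an official fix; it assumes the Module.to overload that takes a Device (note the double backticks, since to is an F# keyword):

open TorchSharp
open type torch
open type torch.nn

torch.set_default_device(torch.CUDA)

let embedding = Embedding(2, 3)
// Workaround (not an official fix): move the module's parameters to CUDA explicitly.
embedding.``to``(torch.CUDA) |> ignore

let input = tensor([|1|])
input --> embedding   // both sides should now be on cuda:0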


It looks like all modules created this way ignore the default device, although this doesn't cause a crash for other module types. I'm not sure whether the lack of a crash is expected behavior, but it surprised me. For example, the following code works, even though the linear module is created on the CPU:

torch.set_default_device(torch.CUDA)

let linear = Linear(2, 3)
let tensor = tensor([|1.0f; 2.0f|])
tensor --> linear

The output tensor in this case is on the CUDA device, even though the linear module is running on the CPU.
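
To see where things actually end up, the devices can be printed directly; this assumes the usual weight property on Linear and the device property on tensors:

open TorchSharp
open type torch
open type torch.nn

torch.set_default_device(torch.CUDA)

let linear = Linear(2, 3)
let input = tensor([|1.0f; 2.0f|])
let output = input --> linear

// Inspect where the parameters and the result live.
printfn "weight: %O" linear.weight.device
printfn "output: %O" output.device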

brianberns · Feb 01 '25

The Linear module is implemented in managed code, which is why it works.

https://github.com/dotnet/TorchSharp/blob/3760ba3e97ca09a35ad8eab1d5128b62d6d2c2ef/src/TorchSharp/NN/Linear.cs

Embedding is handled by the native torch library.

https://github.com/dotnet/TorchSharp/blob/3760ba3e97ca09a35ad8eab1d5128b62d6d2c2ef/src/TorchSharp/NN/Embedding.cs#L60-L68

As for moving modules to the default device, this seems to be a PyTorch "bug" that is more of a feature at this point. I doubt it will be changed there.

I too have encountered the annoyance of having to move my modules over to the default device. I would be in favor of having TorchSharp do this automatically. I think that a modification to this function would be all that's required.

https://github.com/dotnet/TorchSharp/blob/3760ba3e97ca09a35ad8eab1d5128b62d6d2c2ef/src/TorchSharp/NN/Module.cs#L1138-L1145
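
Until something like that lands, a user-side helper can approximate the behavior from F#. This is only a sketch; the onDevice function below is hypothetical (not part of TorchSharp) and simply calls Module.to right after construction:

open TorchSharp
open type torch
open type torch.nn

// Hypothetical helper (not TorchSharp API): create a module, then move it
// to the given device before first use.
let onDevice (device: torch.Device) (m: #torch.nn.Module) =
    m.``to``(device) |> ignore
    m

torch.set_default_device(torch.CUDA)

let embedding = Embedding(2, 3) |> onDevice torch.CUDA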

ds5678 · Feb 08 '25

@brianberns, thanks for bringing this up.

I've written some code to reproduce a similar scenario with PyTorch, and it works as expected (CUDA is used as the device for both the Embedding and the tensor):

import torch
import torch.nn as nn

# Set the default device to CUDA
torch.set_default_device('cuda')

embedding = nn.Embedding(2, 3)

tensor = torch.tensor([1])

result = embedding(tensor)

print(result)

It seems there are some missing checks (or wrong default assumptions) during the creation of some modules.

We'll investigate and try to address this issue so that TorchSharp can take the default device into account automatically.

ghost · Feb 11 '25