kat
kat copied to clipboard
RuntimeError: Triton Error [CUDA]: context is destroyed
When I use your KAT_Group as a module, it runs normally on cuda:0; however, when I run it on any other GPU I encounter the error mentioned in the title. I tested the code you provided and hit the same issue. Could you please let me know how to fix it?

import torch
import torch.nn as nn
from kat_rational import KAT_Group

class KAN(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""
def __init__(
    self,
    in_features,
    hidden_features=None,
    out_features=None,
    act_cfg=None,
    bias=True,
    drop=0.,
):
    """Build the KAN block: two KAT_Group activations paired with Linear layers.

    Args:
        in_features: size of each input sample.
        hidden_features: hidden width; defaults to ``in_features``.
        out_features: output width; defaults to ``in_features``.
        act_cfg: activation config dict whose ``act_init`` list names the
            init modes of the two KAT_Group activations. Defaults to
            ``dict(type="KAT", act_init=["identity", "gelu"])``.
        bias: whether the Linear layers carry a bias term.
        drop: dropout probability applied after each activation.
    """
    super().__init__()
    # Fix: a dict literal as a default argument is shared across all calls
    # (mutable-default anti-pattern); build the default per-instance instead.
    if act_cfg is None:
        act_cfg = dict(type="KAT", act_init=["identity", "gelu"])
    out_features = out_features or in_features
    hidden_features = hidden_features or in_features
    # NOTE(review): KAT_Group's Triton kernels appear to bind to the CUDA
    # device that is current when they are first compiled; when targeting a
    # non-default GPU, call torch.cuda.set_device(dev) before constructing
    # and running this module — confirm against the kat_rational docs.
    self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
    self.act1 = KAT_Group(mode=act_cfg['act_init'][0])
    self.drop1 = nn.Dropout(drop)
    self.act2 = KAT_Group(mode=act_cfg['act_init'][1])
    self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
    self.drop2 = nn.Dropout(drop)
def forward(self, x):
    """Run the input through the act -> drop -> linear pipeline, twice."""
    # Same order of operations as before, expressed as a fixed stage tuple:
    # note the KAT design applies the activation *before* each Linear.
    stages = (
        self.act1,
        self.drop1,
        self.fc1,
        self.act2,
        self.drop2,
        self.fc2,
    )
    for stage in stages:
        x = stage(x)
    return x
# Smoke test: run the KAN block on a non-default GPU.
# Fix for "RuntimeError: Triton Error [CUDA]: context is destroyed":
# Triton JIT kernels compile against the *current* CUDA device, which stays
# cuda:0 unless changed — make cuda:1 current before building/running.
device = torch.device('cuda:1')
torch.cuda.set_device(device)

N, C = 8, 64
input_tensor = torch.randn(N, C, device=device)
model = KAN(in_features=C, hidden_features=128, out_features=C).to(device)
output = model(input_tensor)
print(output.shape)