kat icon indicating copy to clipboard operation
kat copied to clipboard

RuntimeError: Triton Error [CUDA]: context is destroyed

Open huxiaopang666 opened this issue 9 months ago • 1 comments

when I use your kat_group as a module, it runs normally on cuda:0; however, when I run it on other GPUs, I encounter the error mentioned in the title. I tested the code you provided and experienced the same issue. Could you please let me know how to fix it? import torch import torch.nn as nn from kat_rational import KAT_Group class KAN(nn.Module): """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_cfg=dict(type="KAT", act_init=["identity", "gelu"]),
        bias=True,
        drop=0.,
):
    super().__init__()
    out_features = out_features or in_features
    hidden_features = hidden_features or in_features

    self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
    self.act1 = KAT_Group(mode = act_cfg['act_init'][0])
    self.drop1 = nn.Dropout(drop)
    self.act2 = KAT_Group(mode = act_cfg['act_init'][1])
    self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
    self.drop2 = nn.Dropout(drop)

def forward(self, x):
    x = self.act1(x)
    x = self.drop1(x)
    x = self.fc1(x)
    x = self.act2(x)
    x = self.drop2(x)
    x = self.fc2(x)
    return x

N, C = 8, 64 input_tensor = torch.randn(N, C).to('cuda:1') model = KAN(in_features=C, hidden_features=128, out_features=C).to('cuda:1') output = model(input_tensor) print(output.shape)

Image

Image

huxiaopang666 avatar Mar 02 '25 07:03 huxiaopang666