multipy

package issues with functions under C extensions

Open • d4l3k opened this issue 2 years ago • 1 comment

import torch
from torch.package import PackageExporter, PackageImporter

output_path = "/tmp/model.pt"


def save_load(model):
    # Export: keep torch itself external and bundle everything else into the package.
    with PackageExporter(output_path) as e:
        e.extern("torch.**")
        e.intern("**")
        e.save_pickle("model", "model.pkl", model)

    # Re-import the pickled model from the package archive.
    imp = PackageImporter(output_path)
    return imp.load_pickle("model", "model.pkl")


model = torch.nn.TransformerEncoderLayer(
    d_model=64,
    nhead=2,
    dim_feedforward=64,
    dropout=1.0,
    batch_first=True,
    activation='gelu',
    norm_first=True,
)
save_load(model)
print("pass")  # only reached if the round trip succeeds

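The issue is that F.gelu can't be loaded from the package due to an import error. With activation='gelu', the layer stores F.gelu as an attribute, and F.gelu is a builtin defined in torch._C._nn, so loading the pickle has to import that module: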

ModuleNotFoundError: No module named 'torch._C._nn'; 'torch._C' is not a package

d4l3k • May 19 '22
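The error message comes from Python's import system rather than from torch.package itself: a dotted name can only be imported beneath a package (a module with a __path__), and the error says torch._C is not one. A minimal, torch-free sketch of the same failure, using the standard library's math module as a stand-in for the compiled torch._C extension; the helper name try_import is hypothetical.

```python
import importlib


def try_import(dotted_name):
    """Return the ModuleNotFoundError message for dotted_name, or None if it imports."""
    try:
        importlib.import_module(dotted_name)
    except ModuleNotFoundError as exc:
        return str(exc)
    return None


# "math" is a single built-in module with no __path__ (like torch._C, which the
# error reports is "not a package"), so no dotted name beneath it can be found.
message = try_import("math.nn")
print(message)  # prints: No module named 'math.nn'; 'math' is not a package
```

The message has the same shape as the one in the traceback, which suggests the failure is in resolving torch._C._nn as an importable module path, not in the model itself.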

You can work around this by not storing any torch.nn.functional functions as attributes on the module, i.e. avoid self.foo = F.gelu.

d4l3k • Jun 28 '22
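Why the workaround helps can be illustrated with plain pickle, which torch.package's save_pickle builds on. A module-level function kept in an attribute is recorded in the pickle by its module and name, so loading requires importing that module. A function that is only called inside a method never enters the pickle. This is a sketch under stated assumptions: math.sqrt stands in for F.gelu, the class names are hypothetical, and torch.package is assumed to behave like plain pickle in this respect.

```python
import math
import pickle


class StoresFunction:
    """Keeps a module-level builtin as an attribute (analogous to self.foo = F.gelu)."""

    def __init__(self):
        self.fn = math.sqrt

    def __call__(self, x):
        return self.fn(x)


class CallsFunction:
    """Calls the builtin inside a method instead of storing it on the instance."""

    def __call__(self, x):
        return math.sqrt(x)


stored = pickle.dumps(StoresFunction())
called = pickle.dumps(CallsFunction())

# The first pickle references math.sqrt by module and name ("math", "sqrt"),
# so unpickling must import "math"; the second only references the class.
print(b"sqrt" in stored, b"sqrt" in called)  # prints: True False
```

In the repro, the analogous change would be to call F.gelu inside forward instead of keeping it as an attribute, which is the pattern the workaround describes.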