Pointnet_Pointnet2_pytorch icon indicating copy to clipboard operation
Pointnet_Pointnet2_pytorch copied to clipboard

pth 转 pt 时会失败

Open wangsale opened this issue 2 years ago • 1 comment

import torch from models import pointnet2_part_seg_msg

def to_categorical(y, num_classes):
    """One-hot encode a tensor of class ids.

    Each entry of ``y`` selects a row of the ``num_classes`` x ``num_classes``
    identity matrix, so the result has the shape of ``y`` with one extra
    trailing dimension of size ``num_classes``.

    Args:
        y: Tensor of class ids. It is used as an index, so it must hold
            integer values in ``[0, num_classes)``.
        num_classes: Number of classes, i.e. the width of each one-hot row.

    Returns:
        A float tensor of one-hot rows located on the same device as ``y``.
    """
    # Build the identity directly on y's device and index with the tensor
    # itself. This avoids a CPU/NumPy round-trip and, unlike a bare .cuda(),
    # keeps the result on the device y lives on rather than the default GPU.
    return torch.eye(num_classes, device=y.device)[y]

# Export the weights stored in b2.pth to a TorchScript file (b2.pt) by tracing
# the part-segmentation model with dummy inputs.
model = pointnet2_part_seg_msg.get_model(4, False)
model.eval()  # trace in inference mode

# NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only machine; the model and the dummy inputs below stay on the CPU.
z = torch.load('b2.pth', map_location='cpu')
# NOTE(review): training scripts commonly save a dict that wraps the weights
# under 'model_state_dict'. Unwrap it when present, otherwise treat the loaded
# object as the state dict itself -- confirm against how b2.pth was written.
state_dict = z['model_state_dict'] if isinstance(z, dict) and 'model_state_dict' in z else z
model.load_state_dict(state_dict)

example = torch.rand(1, 3, 2048)

# to_categorical uses the label values as indices, so they must be integer
# class ids; torch.rand produces floats, which cannot be used for indexing.
# With a single class the only valid id is 0.
label = torch.zeros(1, 1, dtype=torch.long)

# NOTE(review): the one-hot width passed here (1) has to match what the
# model's forward pass expects for its class-label input -- verify.
traced_script_module = torch.jit.trace(model, (example, to_categorical(label, 1)))
traced_script_module.save("b2.pt")

wangsale avatar Mar 09 '23 01:03 wangsale

image

wangsale avatar Mar 09 '23 01:03 wangsale