keypoint_rcnn_training_pytorch
keypoint_rcnn_training_pytorch copied to clipboard
Converting a trained model from a .pth checkpoint to a TorchScript .pt file so it can be used in mobile applications
import torch
import torchvision
from torchvision.models.detection.anchor_utils import AnchorGenerator

# Load the pre-trained model from the .pth file.
# NOTE(review): the loading code is not part of this snippet. ``model``,
# ``iterator`` (a data-loader iterator yielding (images, targets)) and
# ``device`` are assumed to be defined earlier in the script -- confirm.

# Take one batch and run a quick inference as a sanity check before export.
images, targets = next(iterator)
images = [image.to(device) for image in images]

model.to(device)
model.eval()
with torch.no_grad():
    output = model(images)

print("Predictions: \n", output)

# Convert the model to TorchScript.
# The previous call, torch.jit.trace(model, output), passed the model's
# *predictions* (a list of dicts) where an example *input* is required.
# This made the model run on non-image data, which is the likely source of
# "AttributeError: 'str' object has no attribute 'shape'".
# Even with real example inputs, tracing records only one execution path.
# Detection models branch on their data and return variable-length lists of
# dicts, so torch.jit.script (which compiles the control flow) is used
# instead; torchvision's detection models are written to be scriptable.
# NOTE: a scripted torchvision R-CNN returns a (losses, detections) tuple
# rather than just the detections list -- account for this when loading it.
# NOTE(review): for mobile deployment the model is usually exported on CPU;
# consider moving it to "cpu" before scripting if ``device`` is a GPU.
traced_model = torch.jit.script(model)

# Save the TorchScript model to a .pt file.
traced_model.save('traced_model.pt')
Error raised when tracing the model: `AttributeError: 'str' object has no attribute 'shape'`