UGATIT-pytorch
Export to ONNX
Can the weights be exported to ONNX?
I'm wondering the same thing. I'm trying this in my own fork, but I'm not very experienced and it's not quite working.
Update: I made a script to convert the model to ONNX:
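```python
import torch
import torch.onnx
from torch import nn

from networks import ResnetGenerator  # networks.py from the UGATIT-pytorch repo


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Must match the hyperparameters used for training (light model, 4 residual blocks, 256x256)
        self.genA2B = ResnetGenerator(input_nc=3, output_nc=3,
                                      ngf=64, n_blocks=4, img_size=256, light=True).to('cpu')

    def forward(self, x):
        # ResnetGenerator returns (image, cam_logit, heatmap); keep only the image
        out, _, _ = self.genA2B(x)
        # Optional 2x nearest-neighbour upscale (align_corners is only valid for interpolating modes)
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        return out


model = Model()
# The step number in the filename is a guess; point this at a checkpoint that actually exists.
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine.
params = torch.load('/content/Cats2dogs_ONNX/results/cat2dog/model/cat2dog_params_0002000.pt',
                    map_location='cpu')
model.genA2B.load_state_dict(params['genA2B'])
model.genA2B.eval()

# Dummy NCHW input used for tracing; the exported model expects this exact shape
# unless dynamic_axes is passed to torch.onnx.export
random_input = torch.randn(3, 3, 256, 256, dtype=torch.float32)
# you can add however many inputs your model or task requires
input_names = ["real_A"]
output_names = ["fake_A2B"]
torch.onnx.export(model.genA2B, random_input, 'model.onnx', verbose=False,
                  input_names=input_names, output_names=output_names,
                  opset_version=11)
```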
However, there is a problem: the PyTorch ONNX exporter does not support the torch.var() operators in the model (at least not yet), so the export fails. Does anyone know a workaround to replace the torch.var operators with something else while keeping the model working?
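Update: I fixed it! Just replace every torch.var(...) call with torch.std(...) ** 2 (same arguments), and the export should succeed. Both functions use the unbiased estimator by default, so the model's outputs stay the same.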
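A minimal sketch of the substitution, for illustration only. The three helper functions are hypothetical (they are not part of UGATIT-pytorch); each computes a per-channel mean and variance over H and W of an NCHW tensor, the kind of statistic the model's normalization layers compute. The first uses torch.var as the original code does, the second uses the torch.std(...) ** 2 workaround, and the third is an extra alternative, an assumption not taken from this thread, that spells the unbiased variance out with elementary ops. The check only confirms the values agree in PyTorch; it does not run the ONNX export itself.

```python
import torch


def instance_stats_var(x: torch.Tensor):
    # Original form: mean and unbiased variance over the spatial dims (H, W)
    return torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True)


def instance_stats_std(x: torch.Tensor):
    # Workaround from the thread: torch.std(...) ** 2.
    # torch.std and torch.var both default to the unbiased (N - 1) estimator.
    return torch.mean(x, dim=[2, 3], keepdim=True), torch.std(x, dim=[2, 3], keepdim=True) ** 2


def instance_stats_manual(x: torch.Tensor):
    # Hypothetical alternative: unbiased variance written with mean/sub/pow/sum only
    mean = torch.mean(x, dim=[2, 3], keepdim=True)
    n = x.shape[2] * x.shape[3]
    var = torch.sum((x - mean) ** 2, dim=[2, 3], keepdim=True) / (n - 1)
    return mean, var


if __name__ == "__main__":
    x = torch.randn(2, 8, 16, 16)
    _, v_ref = instance_stats_var(x)
    _, v_std = instance_stats_std(x)
    _, v_manual = instance_stats_manual(x)
    # Both comparisons should print True (equal up to float rounding)
    print(torch.allclose(v_ref, v_std, atol=1e-5), torch.allclose(v_ref, v_manual, atol=1e-5))
```

If a particular PyTorch version also refuses torch.std, the manual form is a plausible fallback, since it relies only on mean, subtraction, squaring and summation.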