once-for-all
How to export one of the subnets?
I'd like to export some of the subnets, either as an ONNX file or as a .pth file plus the network class. How can I do it?
Yes, I have the same question.
You can use the `torch.onnx` package.
Once you have your trained OFA network as `ofa_network`, you can sample a random subnet with `ofa_network.sample_active_subnet()` and then cut that subnet out with `subnet = ofa_network.get_active_subnet()`. Don't forget to reset the batch norm statistics with `reset_running_statistics(net=subnet)`. Then you can export it like any other model:
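```python
import torch

torch.onnx.export(
    subnet,
    torch.randn(1, 3, 224, 224),  # dummy input: one 224x224 RGB image
    'model_name.onnx',
    export_params=True,  # store the trained weights inside the ONNX file
)
```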
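Resetting the batch norm statistics matters because the running mean and variance stored in the supernet were accumulated while many different subnets were trained, so they do not match the activations of the one subnet you extracted. As a rough illustration of the idea (a generic PyTorch sketch, not OFA's own `reset_running_statistics`), you can clear each BatchNorm layer's buffers, switch it to a cumulative average with `momentum=None`, and forward a few calibration batches in training mode. `recalibrate_bn` is a hypothetical helper, and the random tensors stand in for real training images.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def recalibrate_bn(model: nn.Module, batches) -> nn.Module:
    """Recompute BatchNorm running statistics from calibration batches.

    Hypothetical helper for illustration; not OFA's reset_running_statistics().
    """
    bn_layers = [m for m in model.modules() if isinstance(m, BN_TYPES)]
    saved_momentum = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.reset_running_stats()  # running_mean=0, running_var=1, counter=0
        bn.momentum = None        # None -> cumulative average over all batches
    model.train()                 # BN updates running stats only in train mode
    with torch.no_grad():         # forward passes only, no weight updates
        for x in batches:
            model(x)
    for bn, momentum in zip(bn_layers, saved_momentum):
        bn.momentum = momentum    # restore the original update rule
    model.eval()
    return model


# Toy model and random data standing in for the subnet and real images.
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
calib = [torch.randn(4, 3, 32, 32) for _ in range(5)]
recalibrate_bn(net, calib)
```

With `momentum=None`, PyTorch weights every batch equally, so the final running mean is the average of the per-batch means rather than an exponential moving average dominated by the last few batches.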
How can I extract the subnet that corresponds to one of the preset configs, such as "pixel2_lat@[email protected]_finetune@25"? After loading the big OFA network pretrained on a custom dataset, I cannot figure out how to get the subnet for a preset config.
I have not tried this myself, but I would try something like this.
Either:
Write a script that generates, from the net.config file, an architecture configuration in the form required by this function. Then proceed as described previously.
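For the first option, the conversion mostly amounts to grouping the blocks into stages and padding each stage to the supernet's maximum depth. This assumes the function meant here is the supernet's `set_active_subnet(ks=..., e=..., d=...)`, which takes one kernel size and one expand ratio per elastic block slot plus one depth per stage. The sketch below is plain Python and works on a simplified, hypothetical input: a flat list of block dicts with `kernel_size`, `expand_ratio` and `out_channels`. The real `net.config` nests these values inside each block entry, and its fixed first block is not elastic and would have to be dropped first. It also assumes that a new stage starts whenever `out_channels` changes and that the maximum depth is 4, as in the MobileNetV3-based OFA networks. `config_to_arch` is a hypothetical name.

```python
from typing import Dict, List


def config_to_arch(blocks: List[Dict], max_depth: int = 4) -> Dict[str, List]:
    """Turn a flat list of block settings into ks/e/d lists (hypothetical helper).

    Assumes `blocks` excludes the fixed first block and that a new stage
    begins whenever `out_channels` changes.
    """
    # Group consecutive blocks with the same output width into stages.
    stages: List[List[Dict]] = []
    for block in blocks:
        if not stages or block["out_channels"] != stages[-1][-1]["out_channels"]:
            stages.append([])
        stages[-1].append(block)

    ks: List[int] = []
    e: List = []
    d: List[int] = []
    for stage in stages:
        if len(stage) > max_depth:
            raise ValueError(f"stage has {len(stage)} blocks, max depth is {max_depth}")
        d.append(len(stage))
        # Fill the unused slots by repeating the last active block; the
        # extracted subnet only keeps the first d blocks of each stage.
        padded = stage + [stage[-1]] * (max_depth - len(stage))
        ks.extend(b["kernel_size"] for b in padded)
        e.extend(b["expand_ratio"] for b in padded)
    return {"ks": ks, "e": e, "d": d}


example = [
    {"kernel_size": 3, "expand_ratio": 4, "out_channels": 24},
    {"kernel_size": 5, "expand_ratio": 3, "out_channels": 24},
    {"kernel_size": 7, "expand_ratio": 6, "out_channels": 40},
]
arch = config_to_arch(example)
```

If the assumption about the target function holds, you would then call `ofa_network.set_active_subnet(**arch)` followed by `subnet = ofa_network.get_active_subnet()`.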
Or:
Create a model with the desired architecture, as it is done here. In your example, use "pixel2_lat@[email protected]_finetune@25" as `net_id`.
Now a little work is needed: you have to load the weights from your OFA network into the subnetwork. For that, write a function similar to this one, but load only the values from your OFA network that the subnet architecture actually needs.
You also have to call `reset_running_statistics()` for the subnet.
Then you can export the subnetwork as described previously.
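For the weight-loading step of the second option, the core operation is copying the part of each supernet tensor that fits into the smaller subnet tensor. As far as I know, OFA sorts channels by importance during training, so a narrower layer uses the leading channels. Its smaller kernels, however, are computed from the center of the larger kernel via learned transformation matrices, so the plain center crop used below only approximates what `get_active_subnet()` produces. `copy_leading_slice` is a hypothetical helper, and it assumes you have already matched each subnet parameter to its supernet counterpart.

```python
import torch


def copy_leading_slice(src: torch.Tensor, dst: torch.Tensor) -> None:
    """Copy the part of `src` that fits into `dst` (hypothetical helper).

    Channel/feature dimensions take the leading entries; the spatial kernel
    dimensions of 4-D conv weights take the centered window.
    """
    if src.dim() != dst.dim():
        raise ValueError(f"rank mismatch: {tuple(src.shape)} vs {tuple(dst.shape)}")
    index = []
    for axis, (s, t) in enumerate(zip(src.shape, dst.shape)):
        if t > s:
            raise ValueError(f"dst is larger than src along axis {axis}")
        if dst.dim() == 4 and axis >= 2:
            start = (s - t) // 2          # centered crop, e.g. 7x7 -> 5x5
            index.append(slice(start, start + t))
        else:
            index.append(slice(0, t))     # first t channels / features
    with torch.no_grad():                 # allows copying into nn.Parameter
        dst.copy_(src[tuple(index)])


big = torch.arange(2 * 3 * 5 * 5, dtype=torch.float32).reshape(2, 3, 5, 5)
small = torch.zeros(1, 2, 3, 3)
copy_leading_slice(big, small)
```

Here `small` receives the first output channel, the first two input channels, and the central 3x3 window of the 5x5 kernel.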