
Model Exporting (.pt2)

Open isaaccorley opened this issue 8 months ago • 5 comments

Summary

It's common in production environments to export a torch model to a file that can be loaded without the model code (only the exported archive is needed). We should consider storing our weights on HuggingFace with additional versions exported to the .pt2 archive format. If I understand correctly, this avoids the security issues with pickling, and the archive can also store metadata such as transforms and other hyperparameters.

A common example of this is:

import torch
from torchgeo.models import Unet_Weights, unet

weights = Unet_Weights.SENTINEL2_3CLASS_FTW
model = unet(weights=weights)
model.eval()  # export the inference graph

# Trace the model with an example input and save the exported program
args = (torch.randn(1, 8, 256, 256),)
exported = torch.export.export(model, args=args)
torch.export.save(exported, 'model.pt2')

Then the model can be loaded with only torch as a dependency (no model code or dependencies needed!) like so:

import torch

model = torch.export.load("model.pt2").module()
x = torch.randn(1, 8, 256, 256)
print(model(x).shape)  # (1, 3, 256, 256)

CC @rbavery @ljstrnadiii @jiayuasu @calebrob6 @adamjstewart

isaaccorley avatar Apr 15 '25 18:04 isaaccorley

I really like this in general. I am marginally concerned about this though:

[Screenshot of the torch.export documentation warning that the API may still see backwards-incompatible changes]

If this is something Wherobots needs, I'm inclined to move forward with it. However, if the main motivation is to get around security issues with pickling, note that torch.hub.load_state_dict_from_url has the following option which gets around some (but not all) security issues:

weights_only (bool, optional) – If True, only weights will be loaded and no complex pickled objects. Recommended for untrusted sources. See load() for more details.
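For illustration, the same `weights_only` flag exists on `torch.load` itself, which `load_state_dict_from_url` calls under the hood. A minimal sketch (toy `nn.Linear` standing in for a real checkpoint) of what the restricted unpickler permits:

```python
import torch
from torch import nn

# Save a plain state_dict (tensors only, no arbitrary Python objects)
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "linear.pt")

# weights_only=True restricts unpickling to tensors/primitive containers,
# so a malicious pickle payload cannot execute code on load
state = torch.load("linear.pt", weights_only=True)
model.load_state_dict(state)
```

Note this only mitigates the unpickling attack surface; it doesn't give you the code-free deployment story that .pt2 does.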

adamjstewart avatar Apr 16 '25 17:04 adamjstewart

This is mainly needed for our production environments. Note that I can simply write a script that loops through all torchgeo pretrained models and exports them to .pt2 format. We don't need to replace the .pt files on HuggingFace, only support an additional option. I can take on the work of managing them on HuggingFace.
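For illustration, a sketch of such a loop. It assumes torchgeo's torchvision-style `list_models` / `get_model` / `get_model_weights` API and that each weight enum's `meta` dict records its band count (`in_chans`); the output layout and `pt2_path` helper are hypothetical:

```python
from pathlib import Path


def pt2_path(model_name: str, out_dir: str = "exports") -> Path:
    """Hypothetical output location for an exported archive."""
    return Path(out_dir) / f"{model_name}.pt2"


def export_all(out_dir: str = "exports") -> None:
    # torchgeo is imported lazily so the path helper stays usable without it
    import torch
    from torchgeo.models import get_model, get_model_weights, list_models

    Path(out_dir).mkdir(exist_ok=True)
    for name in list_models():
        for weights in get_model_weights(name):
            model = get_model(name, weights=weights).eval()
            # assumption: every weight enum records its expected band count
            in_chans = weights.meta.get("in_chans", 3)
            args = (torch.randn(1, in_chans, 256, 256),)
            exported = torch.export.export(model, args=args)
            torch.export.save(exported, pt2_path(f"{name}_{weights.name.lower()}", out_dir))
```

Some architectures may need per-model input sizes or non-strict export; a real script would special-case those.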

isaaccorley avatar Apr 16 '25 20:04 isaaccorley

Ditto what Isaac said! Some additional reasons .pt2 is really helpful for us and others cataloguing models and using them for inference:

  • We can store GeoAI-specific metadata in a predictable way. We'd like to define a regular way to store STAC MLM metadata within a .pt2 archive. This would probably go in the .pt2 archive's extra/ directory. We could specify that MLM JSON items be stored in a JSON file with a .mlm extension for discoverability, or require a name like mlm_item_<model_name>.json
  • We can store additional metadata about data provenance and other domain-specific metadata, in addition to metadata about model transforms and hyperparameters.
  • We'd like to support a single archive that packages weights, inference-optimized model artifacts, and metadata in a standardized way.
  • There are some other features of the .pt2 spec that could be useful to us later, like storing model server request and response objects and sample data inputs for reproducibility/documentation.

On the breaking changes callout, I expect the spec and APIs for torch.export to evolve, partly with feedback from users like us and TorchGeo. But I think we could adapt our internal use of torch.export without user-facing changes in TorchGeo.

rbavery avatar Apr 16 '25 22:04 rbavery

Alright, I would say let's move forward with it. We can still rehost all weights on TorchGeo's HF. Note that we've been trying to follow the following naming convention and I would like to keep this:

check_hash (bool, optional) – If True, the filename part of the URL should follow the naming convention filename-<sha256>.ext where <sha256> is the first eight or more digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False
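For illustration, producing a name that satisfies this convention is a few lines of stdlib Python (the `hashed_filename` helper name is my own):

```python
import hashlib
from pathlib import Path


def hashed_filename(path: str, digits: int = 8) -> str:
    """Return 'stem-<sha256 prefix>.ext', the naming convention torch.hub's
    check_hash option verifies (first eight or more hex digits of the file's
    SHA256)."""
    p = Path(path)
    sha = hashlib.sha256(p.read_bytes()).hexdigest()
    return f"{p.stem}-{sha[:digits]}{p.suffix}"
```

So an export script would rename each `model.pt2` to e.g. `model-<sha256[:8]>.pt2` before uploading to HF.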

adamjstewart avatar Apr 17 '25 09:04 adamjstewart

I spoke with Angela Yi from the PyTorch team in the PyTorch Slack about improving support for storing nn.Module transforms from kornia in the same .pt2 archive as the model. This would allow loading models and inference-only transforms (or any kind of transforms) together as a single callable, making it easier to use the model immediately without figuring out how to run the correct transforms. She said she'll look into it and that it seems like a generally useful feature for PyTorch. cc @isaaccorley

For now, I talked with Isaac, who is working on making the .pt2 archives; we can skip storing transforms in .pt2 and add them back when/if that feature becomes available upstream.

rbavery avatar Apr 22 '25 19:04 rbavery