Model precision: NaN outputs with torch.float16
I'm trying to run the model from this link (https://docs-assets.developer.apple.com/ml-research/models/mdm/flickr1024/vis_model.pth) with precision torch.float16. However, the output contains NaN values. While debugging the inference process, I found that some activation values in the model exceed the maximum value representable in torch.float16.
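A minimal sketch (plain PyTorch, independent of ml-mdm) of why an overflow shows up as NaN rather than as an error: a value above the float16 maximum becomes inf when cast, and arithmetic such as inf - inf then yields NaN, which propagates through the rest of the network. The values 1000.0 and 70000.0 are arbitrary illustrations, not taken from the model.

```python
import torch

# Largest finite float16 value, 65504.0.
FP16_MAX = torch.finfo(torch.float16).max
print(FP16_MAX)  # 65504.0

# Illustrative values: both are finite in float32, but 70000.0 is above the
# float16 range and becomes +inf after the cast.
x32 = torch.tensor([1000.0, 70000.0], dtype=torch.float32)
x16 = x32.to(torch.float16)
print(x16)

# Once an inf exists, an operation like inf - inf (for example, subtracting a
# mean inside a normalization layer) produces NaN, which then spreads downstream.
nan_result = x16[1] - x16[1]
print(nan_result)
```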
Running with torch.float32, I inserted the following print statement
print(">>>Debug unet-620: ", x.mean(), temb.mean())
into https://github.com/apple/ml-mdm/blob/main/ml-mdm-matryoshka/ml_mdm/models/unet.py#L543
and got the logs below:
Debug unet-620: tensor(-0.0274, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(1145381., device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(142.4138, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-22.8761, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-3.5595, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-0.2938, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-0.3821, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.0356, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.1432, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.1935, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.7042, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.0997, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.4697, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-1.0218, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-6.7193, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-9.7318, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-3.2291, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-137.9206, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-129216.2422, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-672989.1250, device='cuda:0') tensor(1961.6932, device='cuda:0')
====================================
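As a quick check on the log above, this short plain-Python script (values copied from the x.mean() column, in order) lists which calls already have a mean whose magnitude exceeds the float16 maximum. Since the magnitude of a mean can never exceed the largest element magnitude, each of those tensors must contain at least one element that overflows float16.

```python
# x.mean() values copied, in order, from the float32 debug log above.
logged_x_means = [
    -0.0274, 1145381.0, 142.4138, -22.8761, -3.5595, -0.2938, -0.3821,
    -0.0356, -0.1432, -0.1935, -0.7042, -0.0997, -0.4697, -1.0218,
    -6.7193, -9.7318, -3.2291, -137.9206, -129216.2422, -672989.1250,
]

FP16_MAX = 65504.0  # largest finite torch.float16 value

# (0-based call index, mean) for every call whose mean alone is out of range.
overflowing = [(i, v) for i, v in enumerate(logged_x_means) if abs(v) > FP16_MAX]
print(overflowing)
```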
It's clear that some values exceed 65504 (the maximum value for torch.float16): several of the logged means are already larger than that in magnitude, so at least one activation in each of those tensors must be too. Is there any way I can finetune this model to reduce the intermediate activation values so that it can run with torch.float16? Alternatively, could you please provide a new model whose activation values stay within the torch.float16 range?
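Instead of editing unet.py to add print statements, the same check can be sketched with standard PyTorch forward hooks attached to every submodule, run once in float32 so the large values stay finite and measurable. register_overflow_hooks is a hypothetical helper name; the sketch assumes only that the loaded model is a torch.nn.Module, and it skips modules whose output is not a single tensor. The toy nn.Sequential at the end stands in for the real UNet.

```python
import torch
import torch.nn as nn

FP16_MAX = torch.finfo(torch.float16).max


def register_overflow_hooks(model: nn.Module, limit: float = FP16_MAX):
    """Hypothetical helper: record submodules whose output magnitude exceeds `limit`.

    Returns (report, handles): `report` maps module name -> largest |output|
    seen above the limit, and `handles` lets you remove the hooks afterwards.
    """
    report = {}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # This sketch only inspects plain floating-point tensor outputs;
            # tuple or dict outputs are ignored.
            if isinstance(output, torch.Tensor) and output.is_floating_point():
                peak = output.detach().abs().max().item()
                if peak > limit:
                    report[name] = max(peak, report.get(name, 0.0))
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module (its name is the empty string)
            handles.append(module.register_forward_hook(make_hook(name)))
    return report, handles


# Toy stand-in for the real model: large weights push outputs past 65504.
toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
with torch.no_grad():
    toy[0].weight.fill_(1e5)
    toy[0].bias.zero_()

report, handles = register_overflow_hooks(toy)
with torch.no_grad():
    toy(torch.ones(1, 4))  # run in float32 so the values stay finite
for h in handles:
    h.remove()
print(report)
```

Because the hooks are keyed by module name, the report points directly at the blocks that would need attention (for example, to keep them in float32) rather than at a single print location.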