IMELE
Inference for other areas
Thanks for sharing this work @speed8928. However, I couldn't get meaningful predictions for another area using the provided pre-trained model. The issues are:
1. The output is not representative of the input image.
2. The output min and max values are only -0.0128 and 0.304.
Below is the inference script I made. I am using Python 3.11.10 and torch 2.5.0.
Note: The only change I made to the codebase is updating line 31 of models/net.py to `x = x_block0.view(-1, 440, 440)`.
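The hard-coded `view(-1, 440, 440)` only works when the tensor holds a whole multiple of 440 x 440 values. Other tile sizes either raise a `RuntimeError` or, if the count happens to divide evenly, get silently rearranged into the wrong layout. As a sketch, assuming the last two dimensions of `x_block0` are the spatial height and width (not verified against `models/net.py`), a size-agnostic variant could read them from the tensor itself. `flatten_to_maps` is a hypothetical name, not part of the repo.

```python
import torch


def flatten_to_maps(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical size-agnostic stand-in for x.view(-1, 440, 440).

    Collapses all leading dimensions and keeps the last two (H, W) unchanged.
    """
    return x.reshape(-1, *x.shape[-2:])


# For a 440 x 440 input both forms give the same shape; for other sizes only
# the variant above adapts, while the hard-coded view fails or scrambles.
x_block0 = torch.zeros(1, 1, 440, 440)
assert flatten_to_maps(x_block0).shape == x_block0.view(-1, 440, 440).shape
```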
The `sample.tif` used in step 2 is attached here: sample.zip (368 KB)
```python
import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
import torch

from src.models import modules, net, senet


def define_model(model_file):
    original_model = senet.senet154(pretrained='imagenet')
    encoder = modules.E_senet(original_model)
    model = net.Model(encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])
    state_dict = torch.load(model_file, weights_only=True)['state_dict']
    model.load_state_dict(state_dict, strict=False)
    return model.to(device='cpu')


# 1. Load the model
dsm_model = define_model('Block0_skip_model_110.pth.tar')

# 2. Read the sample image and perform pre-processing
with rio.open('sample.tif') as src:
    image = src.read().astype('float32') / 255.0  # image.shape -> (3, 440, 440)

# 3. Get the prediction and perform post-processing
image_tensor = torch.as_tensor(np.expand_dims(image, axis=0)).to('cpu').float()
pred = dsm_model(image_tensor)
pred = torch.nn.functional.interpolate(pred, size=(440, 440), mode='bilinear')
output = pred.squeeze(axis=(0, 1)).detach().numpy()

# 4. Plot the image
fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.imshow(np.transpose(image, (1, 2, 0)))
im = ax1.imshow(output)
plt.colorbar(im, ax=ax1, orientation='horizontal')
plt.savefig('plot.png', bbox_inches='tight')
plt.close()
```
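Two things in the script may be worth ruling out first. `strict=False` makes `load_state_dict` silently skip keys that don't match; if the checkpoint was saved from `nn.DataParallel` (not confirmed here), every key starts with `module.` and effectively no trained weights would be loaded. The model is also never switched to `eval()`, so BatchNorm normalizes with the statistics of a single image. The sketch below addresses both and optionally applies ImageNet mean/std normalization, which is an assumption about how the model was trained, not something confirmed by the repo. `load_checkpoint` and `predict` are hypothetical helper names.

```python
import numpy as np
import torch
import torch.nn as nn

# ASSUMPTION: training inputs were normalized with ImageNet statistics.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)


def load_checkpoint(model: nn.Module, model_file: str) -> nn.Module:
    """Hypothetical helper: load weights and report keys that did not match."""
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    state_dict = torch.load(model_file, map_location="cpu", weights_only=True)["state_dict"]
    # nn.DataParallel checkpoints prefix every key with 'module.'; strip it if present
    state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"missing keys: {len(result.missing_keys)}, "
              f"unexpected keys: {len(result.unexpected_keys)}")
    return model


def predict(model: nn.Module, image: np.ndarray, normalize: bool = True) -> np.ndarray:
    """Hypothetical helper: run one (3, H, W) image scaled to [0, 1]; return (H, W).

    Assumes the model outputs a (1, 1, h, w) map, as the script above implies.
    """
    model.eval()  # BatchNorm uses running statistics instead of batch statistics
    x = (image - IMAGENET_MEAN) / IMAGENET_STD if normalize else image
    x = torch.from_numpy(np.ascontiguousarray(x, dtype=np.float32)).unsqueeze(0)
    with torch.no_grad():  # inference only, so no autograd graph is built
        pred = model(x)
        pred = torch.nn.functional.interpolate(
            pred, size=image.shape[1:], mode="bilinear", align_corners=False
        )
    return pred.squeeze(0).squeeze(0).numpy()
```

In the script, `define_model` could call `load_checkpoint(model, model_file)` in place of its two loading lines, and step 3 could become `output = predict(dsm_model, image)`. Running `predict` with both `normalize=True` and `normalize=False` should show quickly whether the normalization assumption matters for this checkpoint.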