IMELE

Inference for other areas

Open VaasuDevanS opened this issue 1 year ago • 0 comments

Thanks for sharing this work @speed8928. However, I couldn't get meaningful predictions for another area with the provided pre-trained model. There are two issues: 1. The output is not representative of the input image. 2. The output min and max values are -0.0128 and 0.304.

Below is the inference script I made. I am using Python 3.11.10 and torch 2.5.0. Note: The only change I made in the codebase is updating line 31 in models/net.py to x = x_block0.view(-1, 440, 440)

sample.tif from step 2 is here: sample.zip (368KB)

import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
import torch

from src.models import modules, net, senet


def define_model(model_file):
    original_model = senet.senet154(pretrained='imagenet')
    encoder = modules.E_senet(original_model)
    model = net.Model(encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])
    state_dict = torch.load(model_file, weights_only=True)['state_dict']
    model.load_state_dict(state_dict, strict=False)
    return model.to(device='cpu')


# 1. Load the model
dsm_model = define_model('Block0_skip_model_110.pth.tar')

# 2. Read the sample image and perform pre-processing
with rio.open('sample.tif') as src:
    image = src.read().astype('float32') / 255.0  # image.shape -> (3, 440, 440)

# 3. Get the prediction and perform post-processing
image_tensor = torch.as_tensor(np.expand_dims(image, axis=0)).to('cpu').float()
pred = dsm_model(image_tensor)
pred = torch.nn.functional.interpolate(pred, size=(440, 440), mode='bilinear')
output = pred.squeeze(axis=(0, 1)).detach().numpy()

# 4. Plot the image
fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.imshow(np.transpose(image, (1, 2, 0)))
im = ax1.imshow(output)
plt.colorbar(im, ax=ax1, orientation='horizontal')
plt.savefig('plot.png', bbox_inches='tight')
plt.close()
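
One thing that may be worth checking in step 1, offered as an assumption rather than a confirmed cause: with strict=False, load_state_dict does not raise on mismatched keys but returns them as missing_keys and unexpected_keys. Checkpoints saved from a model wrapped in torch.nn.DataParallel carry a "module." prefix on every key, and in that case none of the weights would be loaded. The sketch below uses plain dictionaries and made-up key names as stand-ins for model.state_dict() and the checkpoint; strip_prefix and compare_keys are hypothetical helpers, not part of the IMELE codebase.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, similar to what load_state_dict reports."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Toy stand-ins with hypothetical key names; the real keys come from
# model.state_dict() and the checkpoint's 'state_dict' entry.
model_keys = ["E.conv1.weight", "R.conv2.bias"]
checkpoint = OrderedDict([("module.E.conv1.weight", 0), ("module.R.conv2.bias", 1)])

# Before stripping: every key is reported as missing or unexpected.
print(compare_keys(model_keys, checkpoint.keys()))
# After stripping: both lists are empty, so all weights would be matched.
print(compare_keys(model_keys, strip_prefix(checkpoint).keys()))
```

In the script above, the equivalent check would be printing the value returned by model.load_state_dict(state_dict, strict=False) and confirming that both lists are empty.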

[Image: plot.png, the input image next to the model output with a horizontal colorbar]
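
A second thing that may be worth ruling out in step 3: the script calls the model without switching it to evaluation mode. A freshly constructed nn.Module is in training mode, where BatchNorm layers normalize with the statistics of the current batch instead of the stored running statistics. Whether this explains the narrow output range here is an assumption. The sketch below uses a tiny hypothetical stand-in network (toy_model), not the IMELE model, to show the usual inference pattern.

```python
import torch
from torch import nn

# Hypothetical stand-in for the real network: one convolution plus batch norm.
toy_model = nn.Sequential(
    nn.Conv2d(3, 1, kernel_size=3, padding=1),
    nn.BatchNorm2d(1),
)
batch = torch.rand(1, 3, 8, 8)  # same layout as image_tensor: (N, C, H, W)

toy_model.eval()  # BatchNorm uses running statistics, Dropout is disabled
with torch.no_grad():  # no autograd graph is built during inference
    first = toy_model(batch)
    second = toy_model(batch)

print(toy_model.training, first.requires_grad, first.shape)
```

Applied to the script above, this would mean calling dsm_model.eval() after loading and wrapping the forward pass in torch.no_grad(), which also makes the later detach() call unnecessary.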

VaasuDevanS · Oct 23 '24 12:10