Add x_orig param in SequentialEx to allow split models
I was trying to split a model created with unet_learner, and I found that no matter how you split the model, if there is a ResizeToOrig
layer in it you cannot split it, because that layer uses the original input of the network as its reference. Until now.
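For context, SequentialEx stores the block's input on each intermediate result so that layers such as ResizeToOrig and MergeLayer can reach it through .orig. A minimal sketch of the idea behind the change follows; it is a simplified illustration of SequentialEx with the new x_orig parameter, not the exact diff of this PR:

# Simplified illustration of the idea (not the exact diff of this PR):
# forward gains an optional x_orig, so a chunk of a split model can be told
# what the real original input of the whole network was.
class SequentialEx(Module):
    def __init__(self, *layers): self.layers = nn.ModuleList(layers)
    def forward(self, x, x_orig=None):
        if x_orig is None: x_orig = x   # default: behaves exactly as before
        res = x
        for l in self.layers:
            res.orig = x_orig           # ResizeToOrig reads this via .orig
            nres = l(res)
            res.orig = None             # drop the ref to avoid memory leaks
            res = nres
        return res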
I have this example code that I would like to add to the tests, but I don't see any reference on how to include it:
from fastai.vision.all import *
from fastai.vision.gan import *
import torch
# Used to split the model created with unet_learner
class SimplifiedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        layers = [m for m in model.layers]
        m_len = len(model.layers)
        self.layer1 = SequentialEx(*layers[:m_len//3])
        self.layer2 = SequentialEx(*layers[m_len//3:m_len//3*2])
        self.layer3 = SequentialEx(*layers[m_len//3*2:])

    def forward(self, x):
        _x = self.layer1(x)
        # This would have failed before this PR: ResizeToOrig would take this
        # chunk's input as the "original" image, which is not the real original
        # input of the net
        _x = self.layer2(_x, x_orig=x)
        return self.layer3(_x, x_orig=x)
n_samples = 100
n_channels = 3
image_size = 128
n_classes = 2
# Random image tensors and labels
X = torch.randn(n_samples, n_channels, image_size, image_size)
y = torch.randint(0, n_classes, (n_samples, image_size, image_size))
train_dl = DataLoader(list(zip(X,y)), batch_size=32, shuffle=True, device="cuda:0")
dls = DataLoaders(train_dl, train_dl)
model = resnet34
learn = unet_learner(
    dls, model, loss_func=nn.CrossEntropyLoss(),
    normalize=False, n_out=n_classes, n_in=n_channels
)
n_epochs = 5
learn.fit_one_cycle(n_epochs)
# Alternative version, because "reasons"
learn2 = unet_learner(
    dls, model, loss_func=nn.CrossEntropyLoss(),
    normalize=False, n_out=n_classes, n_in=n_channels
)
newmodel = SimplifiedModel(learn.model)
newmodel.to("cuda:0")
loss = learn2.loss_func
opt = learn2.opt  # note: None until the learner builds its optimizer; unused in the forward-only loop below
for n in range(n_epochs):
    print("Epoch", n)
    for inp, target in dls[0]:
        inp = inp.to("cuda")
        target = target.to("cuda")
        pred = newmodel(inp)  # inp is already a tensor; no need to re-wrap it
        error = loss(pred, target)
        print(error)  # The loss differs slightly because the logic is not identical to fastai's, but it works
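The loop above only runs the forward pass. A minimal, hypothetical way to actually train the split model is to attach a plain PyTorch optimizer to its parameters; Adam and lr=1e-3 below are arbitrary choices for illustration:

# Hypothetical training step for the split model, using plain PyTorch rather
# than fastai's training loop
opt = torch.optim.Adam(newmodel.parameters(), lr=1e-3)
for n in range(n_epochs):
    for inp, target in dls[0]:
        inp, target = inp.to("cuda"), target.to("cuda")
        pred = newmodel(inp)
        error = loss(pred, target)
        opt.zero_grad()
        error.backward()
        opt.step()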
Edit: I submitted this before I finished writing it.
Afaik, there is no easy way to do model parallelism in fastai across multiple nodes (not multiple GPUs in one node, but multiple nodes with one GPU each). With this PR, it becomes possible to use the PyTorch RPC module to split a model built with SequentialEx across nodes.
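As a rough, untested sketch of that direction (the worker name, the CPU round-trips, and the RPC setup are illustrative assumptions, not part of this PR):

# Hypothetical sketch: host the middle chunk of the model on another node via
# torch.distributed.rpc. Assumes rpc.init_rpc(...) has already been called on
# every node; "worker1" is an illustrative worker name.
import torch.distributed.rpc as rpc

class RemoteSplitModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        layers = list(model.layers)
        m_len = len(layers)
        self.layer1 = SequentialEx(*layers[:m_len//3])
        # Build the middle chunk directly on the remote worker; the returned
        # RRef is a handle to that remote module
        self.layer2_rref = rpc.remote("worker1", SequentialEx,
                                      args=tuple(layers[m_len//3:m_len//3*2]))
        self.layer3 = SequentialEx(*layers[m_len//3*2:])

    def forward(self, x):
        _x = self.layer1(x)
        # x_orig travels with the RPC call, so a ResizeToOrig inside the remote
        # chunk can still resize against the true network input
        _x = self.layer2_rref.rpc_sync().forward(_x.cpu(), x_orig=x.cpu())
        return self.layer3(_x.to(x.device), x_orig=x)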