pytorch-lr-finder
TrainDataLoaderIter: post-process network prediction
Hello. Currently, the `*DataLoaderIter` classes let us do some custom pre-processing of the (x, y) pairs through the `inputs_labels_from_batch` method.
I have a network where I do some post-processing on the output of the network, e.g. (simplified):
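```python
x, y = next(train_sampler)  # one (inputs, labels) batch
Y_hat = model(x)            # raw network output
y_hat = custom_func(Y_hat)  # post-processing applied to the prediction
loss = mse(y_hat, y)        # loss is computed on the post-processed output
```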
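The simplified step above can be made runnable with small stand-ins. This is only a sketch: `model`, `mse`, `custom_func`, and `train_sampler` are hypothetical placeholders chosen for illustration (a linear layer, `nn.MSELoss`, a channel sum, and a one-batch iterator). It shows that when `custom_func` is built from ordinary torch operations, `loss.backward()` reaches the model's parameters through the post-processing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the names used in the issue.
model = nn.Linear(4, 2)
mse = nn.MSELoss()

def custom_func(Y_hat):
    # Hypothetical post-processing: sum the two output channels.
    # It uses only torch ops, so autograd can differentiate through it.
    return Y_hat.sum(dim=1, keepdim=True)

# A single synthetic (x, y) batch in place of a real data loader.
train_sampler = iter([(torch.randn(8, 4), torch.randn(8, 1))])

x, y = next(train_sampler)
Y_hat = model(x)            # shape (8, 2)
y_hat = custom_func(Y_hat)  # shape (8, 1)
loss = mse(y_hat, y)        # scalar loss on the post-processed output
loss.backward()             # gradients flow back through custom_func into model
```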
Could (or should) the data loader classes offer this as an option, e.g. an "output_labels_from_batch" method, so that we can post-process the output of the model's forward()?
Thanks.
That would mean that custom_func must support a backward pass (i.e. be differentiable by autograd); otherwise you won't be able to backpropagate through it. If that's the case, why not just call the function in the forward() method of the class that defines model?
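A minimal sketch of that suggestion, assuming custom_func is built from differentiable torch operations: wrap the original model in a module whose forward() applies the post-processing, so any code that only calls `model(x)` sees the post-processed prediction. `PostProcessed` and the example `custom_func` are hypothetical names used for illustration.

```python
import torch
import torch.nn as nn

def custom_func(Y_hat):
    # Hypothetical post-processing; differentiable because it only uses torch ops.
    return Y_hat.sum(dim=1, keepdim=True)

class PostProcessed(nn.Module):
    """Wraps a model so that forward() returns post_fn(model(x))."""

    def __init__(self, model, post_fn):
        super().__init__()
        self.model = model      # registered as a submodule, so its parameters are exposed
        self.post_fn = post_fn

    def forward(self, x):
        return self.post_fn(self.model(x))

base = nn.Linear(4, 2)
wrapped = PostProcessed(base, custom_func)

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.MSELoss()(wrapped(x), y)
loss.backward()  # gradients reach base.weight through the wrapper
```

Because the wrapper is an `nn.Module`, an optimizer built from `wrapped.parameters()` updates the original model's weights. Assuming the finder computes the loss as `criterion(model(inputs), labels)`, the wrapped module could then be passed in place of the original, e.g. `LRFinder(wrapped, optimizer, criterion)`, with no change to the data loader classes.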
Closing due to inactivity