selene
Support NN models with multiple inputs
In addition to the sequence, we'd like to provide another input (of a different size) to the model. A minimal example to illustrate:
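from typing import List

import torch
import torch.nn as nn


class SimpleConv(nn.Module):
    def __init__(self, channels_in: int, channels_out: int):
        super().__init__()  # must run before submodules are assigned
        self.conv_net = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.fc_net = nn.Sequential(
            nn.Linear(channels_in, channels_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        # x[0]: (batch, 1, H, W) image-like input; x[1]: (batch, channels_in) vector input
        y1 = self.conv_net(x[0]).flatten(1)  # (batch, H*W), so it is 2-D like y2
        y2 = self.fc_net(x[1])  # (batch, channels_out)
        return torch.cat((y1, y2), dim=1)  # (batch, H*W + channels_out)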
Do you think we could modify the _get_batch() function to return a tuple(List[np.ndarray], np.ndarray)?
https://github.com/FunctionLab/selene/blob/master/selene_sdk/train_model.py#L346-L355
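A minimal sketch of what that return type could look like, assuming each sample arrives as a pair (inputs, target) where inputs is a tuple holding one array per model input. The helper name collate_multi_input is hypothetical and is not Selene's _get_batch(); it only shows how per-input arrays of different shapes can each be stacked into one batched array.

```python
from typing import List, Sequence, Tuple

import numpy as np


def collate_multi_input(
    samples: Sequence[Tuple[Sequence[np.ndarray], np.ndarray]],
) -> Tuple[List[np.ndarray], np.ndarray]:
    """Stack per-sample inputs and targets along a new leading batch axis.

    Hypothetical helper: each sample is (inputs, target), where inputs is a
    tuple with one array per model input. Different inputs may have different
    shapes, but a given input must have the same shape in every sample.
    """
    n_inputs = len(samples[0][0])
    # One batched array per input position, each of shape (batch, *input_shape).
    batched_inputs = [
        np.stack([inputs[i] for inputs, _ in samples]) for i in range(n_inputs)
    ]
    targets = np.stack([target for _, target in samples])
    return batched_inputs, targets


if __name__ == "__main__":
    # Two samples: a (1000, 4) sequence-like array plus a 10-dim feature vector.
    # The shapes are arbitrary and chosen only for illustration.
    rng = np.random.default_rng(0)
    samples = [
        ((rng.random((1000, 4)), rng.random(10)), rng.random(5))
        for _ in range(2)
    ]
    inputs, targets = collate_multi_input(samples)
    print([x.shape for x in inputs], targets.shape)
    # prints [(2, 1000, 4), (2, 10)] (2, 5)
```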
Maybe we could wrap the code at
https://github.com/FunctionLab/selene/blob/master/selene_sdk/train_model.py#L453-L464
into a function that returns either a single Tensor or a List[Tensor] for the inputs, depending on the type of the provided inputs? Or is there a better design?
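One possible shape for that wrapper, as a sketch: dispatch on whether the inputs are a single array or a list/tuple of arrays, and apply the same conversion to each. The function name convert_inputs and its to_tensor parameter are hypothetical; in a PyTorch training loop, to_tensor might be something like lambda a: torch.from_numpy(a).float(). The converter is passed in so the sketch itself runs with numpy alone.

```python
from typing import Callable, List, Sequence, TypeVar, Union

import numpy as np

T = TypeVar("T")


def convert_inputs(
    inputs: Union[np.ndarray, Sequence[np.ndarray]],
    to_tensor: Callable[[np.ndarray], T],
) -> Union[T, List[T]]:
    """Hypothetical helper: convert one array, or each array in a list/tuple.

    A single ndarray yields a single converted value; a list or tuple yields
    a list with one converted value per input, in the same order.
    """
    if isinstance(inputs, (list, tuple)):
        return [to_tensor(x) for x in inputs]
    return to_tensor(inputs)
```

Passing the converter as a parameter keeps the type dispatch in one place, so moving data to a device (for example, calling .to(device) inside to_tensor) would not need a second branch.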
Thank you for bringing up the need for multiple inputs! As discussed with @kathyxchen, we are planning to support multiple inputs as well as multiple targets. We currently have a couple of major updates underway (adding support for parallelized data loading and improving support for custom targets) that will involve changes to Sampler and Target, so we plan to introduce this after those updates.
Our current plan for the API change is to support multiple types of output from sampler.sample():

- numpy.ndarray, numpy.ndarray (single input, single target)
- tuple(numpy.ndarray), numpy.ndarray (multiple inputs, single target)
- numpy.ndarray, tuple(numpy.ndarray) (single input, multiple targets)
- tuple(numpy.ndarray), tuple(numpy.ndarray) (multiple inputs, multiple targets)

That is, both the input and the targets can be either an array or a tuple of arrays. TrainModel will handle converting the numpy arrays to PyTorch tensors, and it is then up to the user to make sure the model and the criterion handle tuple inputs correctly.
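Under this plan, a criterion used with a multi-target sampler would receive a tuple of targets. A minimal sketch of one way a user could handle that, assuming the model returns one prediction per target in the same order: wrap one loss function per target and sum the (optionally weighted) losses. MultiTargetLoss is a hypothetical name, not part of Selene. With PyTorch, the per-target functions could be nn.MSELoss() or nn.BCELoss() instances; the numpy mse below only stands in for them so the sketch runs without PyTorch.

```python
from typing import Callable, Optional, Sequence

import numpy as np

LossFn = Callable[[np.ndarray, np.ndarray], float]


class MultiTargetLoss:
    """Hypothetical criterion: weighted sum of per-target losses.

    A plain (non-tuple) target is treated as the single-target case and is
    scored with the first loss function only.
    """

    def __init__(self, loss_fns: Sequence[LossFn], weights: Optional[Sequence[float]] = None):
        self.loss_fns = list(loss_fns)
        self.weights = list(weights) if weights is not None else [1.0] * len(self.loss_fns)
        if len(self.weights) != len(self.loss_fns):
            raise ValueError("need one weight per loss function")

    def __call__(self, predictions, targets):
        if not isinstance(targets, tuple):
            return self.weights[0] * self.loss_fns[0](predictions, targets)
        if not len(predictions) == len(targets) == len(self.loss_fns):
            raise ValueError("predictions, targets and loss_fns must align")
        # sum() also works if each loss returns a torch scalar tensor.
        return sum(
            w * fn(p, t)
            for w, fn, p, t in zip(self.weights, self.loss_fns, predictions, targets)
        )


def mse(pred: np.ndarray, target: np.ndarray) -> float:
    # Stand-in for torch.nn.MSELoss so the sketch runs with numpy only.
    return float(np.mean((pred - target) ** 2))
```

Keeping one loss per target lets each output use a loss suited to its type (for example, a classification loss for one target and a regression loss for another) while the training loop still sees a single scalar.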
Hi @jzthree. I've been working on a PR for this today; I'll send it to you for review tomorrow.