Support NN models with multiple inputs

Open sashuIya opened this issue 3 years ago • 2 comments

In addition to the sequence, we'd like to provide another input (of a different size) to the model. A simple example to illustrate:

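from typing import List

import torch
import torch.nn as nn


class SimpleConv(nn.Module):
    def __init__(self, channels_in: int, channels_out: int):
        super().__init__()  # must run before submodules are assigned
        self.conv_net = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )

        self.fc_net = nn.Sequential(
            nn.Linear(channels_in, channels_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        # x[0]: shape (N, 1, H, W); x[1]: shape (N, channels_in)
        # Flatten the conv output so both branches are 2-D before concatenating.
        y1 = self.conv_net(x[0]).flatten(start_dim=1)
        y2 = self.fc_net(x[1])
        y = torch.cat((y1, y2), 1)
        return y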

Do you think we could modify the _get_batch() function to return a tuple(List[np.ndarray], np.ndarray)? https://github.com/FunctionLab/selene/blob/master/selene_sdk/train_model.py#L346-L355

Maybe we could wrap https://github.com/FunctionLab/selene/blob/master/selene_sdk/train_model.py#L453-L464 in a function that returns either a single Tensor or a List[Tensor] for the inputs, depending on the type of the provided inputs? Or would there be a better design solution?
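To make the idea concrete, here is a minimal sketch of what such a wrapper might look like. The name to_tensors and the use_cuda parameter are hypothetical and not part of selene_sdk; the conversion to float32 is also an assumption about what the linked code does.

```python
from typing import List, Sequence, Union

import numpy as np
import torch

ArrayOrArrays = Union[np.ndarray, Sequence[np.ndarray]]
TensorOrTensors = Union[torch.Tensor, List[torch.Tensor]]


def to_tensors(inputs: ArrayOrArrays, use_cuda: bool = False) -> TensorOrTensors:
    """Hypothetical helper (not in selene_sdk): convert one array or several.

    A single ndarray yields a single Tensor; any other sequence of
    ndarrays yields a list of Tensors in the same order.
    """

    def convert(array: np.ndarray) -> torch.Tensor:
        # Assumption: the model expects float32 inputs.
        tensor = torch.from_numpy(np.asarray(array)).float()
        return tensor.cuda() if use_cuda else tensor

    if isinstance(inputs, np.ndarray):
        return convert(inputs)
    return [convert(array) for array in inputs]


# Single input: one Tensor comes back.
single = to_tensors(np.zeros((4, 1, 8, 8)))

# Multiple inputs of different sizes: a list of Tensors comes back.
multi = to_tensors([np.zeros((4, 1, 8, 8)), np.zeros((4, 16))])
```

With something like this, the training loop could pass the result straight to the model without caring whether it holds one tensor or several.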

sashuIya avatar Feb 27 '21 18:02 sashuIya

Thank you for bringing up the need for multiple inputs! As discussed with @kathyxchen, we are planning to support multiple inputs as well as multiple targets. Currently, we have a couple of major updates underway (adding support for parallelized data loading and improving the support for custom targets) that will involve changes to Sampler and Target, so we plan to introduce multi-input support after these updates.

Our current plan for the API change is to support multiple types of output from sampler.sample():

numpy.ndarray, numpy.ndarray (single input, single target)
tuple(numpy.ndarray), numpy.ndarray (multiple inputs, single target)
numpy.ndarray, tuple(numpy.ndarray) (single input, multiple targets)
tuple(numpy.ndarray), tuple(numpy.ndarray) (multiple inputs, multiple targets)

where both inputs and targets can be either an array or a tuple of arrays. TrainModel will handle the conversion from numpy arrays to PyTorch tensors, and then it's up to the user to make sure the Model and Criterion handle tuple inputs correctly.
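As a rough illustration of the user-side part, a criterion that accepts either a single target or a tuple of targets might look like the sketch below. TupleAwareLoss is a hypothetical name, and summing the per-target losses is just one possible choice, not something Selene prescribes.

```python
from typing import Sequence, Union

import torch
import torch.nn as nn

TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]


class TupleAwareLoss(nn.Module):
    """Hypothetical wrapper: apply a base criterion per target and sum."""

    def __init__(self, base_criterion: nn.Module):
        super().__init__()
        self.base_criterion = base_criterion

    def forward(
        self, predictions: TensorOrTensors, targets: TensorOrTensors
    ) -> torch.Tensor:
        # Single target: behave exactly like the wrapped criterion.
        if isinstance(targets, torch.Tensor):
            return self.base_criterion(predictions, targets)
        # Multiple targets: expect one prediction per target.
        if len(predictions) != len(targets):
            raise ValueError("expected one prediction per target")
        losses = [
            self.base_criterion(p, t) for p, t in zip(predictions, targets)
        ]
        return torch.stack(losses).sum()


criterion = TupleAwareLoss(nn.MSELoss())

# Single target.
single_loss = criterion(torch.zeros(4, 3), torch.ones(4, 3))

# Two targets of different sizes.
multi_loss = criterion(
    (torch.zeros(4, 3), torch.zeros(4, 2)),
    (torch.ones(4, 3), torch.ones(4, 2)),
)
```

A weighted sum or per-target criteria would work the same way; the point is only that the tuple handling lives in user code rather than in TrainModel.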

jzthree avatar Mar 01 '21 17:03 jzthree

Hi @jzthree. I've been working on a PR for this today; I'll send it to you for review tomorrow.

sashuIya avatar Mar 01 '21 18:03 sashuIya